Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster Matrix{BlasFloat} * or \ VecOrMatrix{Dual} #589

Open
wants to merge 8 commits into
base: master
Choose a base branch
from

Conversation

alyst
Copy link

@alyst alyst commented Jul 18, 2022

The calculation of the gradient from a multivariate normal prior involves left-division by a triangular matrix.
Currently, when the covariance matrix is fixed, and the random vector depends on model parameters (the typical use case), it is done via a fallback path in LinearAlgebra: the constant matrix gets promoted to Dual type (i.e. it is copied each time the gradient is calculated), and then the generic triangular left division implementation is called. For 100x100 and larger matrices this results in big CPU and memory overhead.

However, when the matrix is constant, we don't need to convert it to Dual. We just have to left divide the dual values vector as well as each partial by this matrix. Since it would be the operation on fixed vectors, LAPACK's trtrs() could be used for much faster division.
This is what this PR does. To avoid excessive copying, it relies on a hack: the array of duals could be accessed as a vector of N+1 floats.

src/dual.jl Outdated Show resolved Hide resolved
src/dual.jl Outdated
Comment on lines 816 to 820
Base.:*(m::Union{LowerTriangular{<:LinearAlgebra.BlasFloat},
UpperTriangular{<:LinearAlgebra.BlasFloat},
StridedMatrix{<:LinearAlgebra.BlasFloat}},
x::StridedVecOrMat{<:Dual}) =
_map_dual_components(Base.Fix1(lmul!, m), (x, _) -> lmul!(m, x), x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's going to be very hard to avoid ambiguities here.

julia> rand(3,3) * Dual.(rand(3),0)
ERROR: MethodError: *(::Matrix{Float64}, ::Vector{Dual{Nothing, Float64, 1}}) is ambiguous.

Candidates:
  *(A::StridedMatrix{T}, x::StridedVector{S}) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:Real}
    @ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:49
  *(m::Union{LowerTriangular{var"#s87", S} where {var"#s87"<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:AbstractMatrix{var"#s87"}}, UpperTriangular{var"#s86", S} where {var"#s86"<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:AbstractMatrix{var"#s86"}}, StridedMatrix{<:Union{Float32, Float64, ComplexF32, ComplexF64}}}, x::StridedVecOrMat{<:Dual})
    @ Main REPL[51]:4

Attaching the rule to mul!(Matrix{<:Dual}, ...) seems less likely to trigger them.

Testing with Test.detect_ambiguities(ForwardDiff) might be a good idea too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Attaching the rule to mul!(Matrix{<:Dual}, ...) seems less likely to trigger them.

The problem is that promotion to Dual already happens at \ (probably for * too).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, I think LinearAlgebra does something like this trick for some mixed real/complex cases (where the strides work out correctly). Maybe mul!(ones(4,4).+im, rand(ComplexF64, 4,4), rand(4,4), true, false) is one? Staying close to the signature used there is probably a way to avoid ambiguities.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

\ might be easier than *, as it doesn't have so many methods.

For *, promotion to make C should happen correctly without this package doing anything, I think.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For *, promotion to make C should happen correctly without this package doing anything, I think.

This is *(m::Triangular, x::Vector)

@codecov-commenter
Copy link

codecov-commenter commented Jul 19, 2022

Codecov Report

Patch coverage: 93.02% and project coverage change: -2.55 ⚠️

Comparison is base (e3670ce) 89.65% compared to head (1d38eac) 87.11%.

❗ Current head 1d38eac differs from pull request most recent head 22d8670. Consider uploading reports for the commit 22d8670 to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #589      +/-   ##
==========================================
- Coverage   89.65%   87.11%   -2.55%     
==========================================
  Files          11        9       -2     
  Lines         967      947      -20     
==========================================
- Hits          867      825      -42     
- Misses        100      122      +22     
Impacted Files Coverage Δ
src/dual.jl 81.81% <93.02%> (-0.34%) ⬇️

... and 9 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Do you have feedback about the report comment? Let us know in this issue.

@alyst
Copy link
Author

alyst commented Jul 20, 2022

Now these functions should be fixed and properly tested.
The latter would require merging JuliaDiff/DiffTests.jl#11 first and updating the deps here.

@alyst alyst changed the title Faster * and / of dual array by constant matrix Faster Matrix{BlasFloat} * or \ of VecOrMatrix{Dual} Jul 20, 2022
@alyst alyst changed the title Faster Matrix{BlasFloat} * or \ of VecOrMatrix{Dual} Faster Matrix{BlasFloat} * or \ VecOrMatrix{Dual} Jul 20, 2022
src/dual.jl Outdated Show resolved Hide resolved
@alyst
Copy link
Author

alyst commented Sep 22, 2022

@mcabbott @fredrikekre @devmotion I've refactored the code so that in-place ldiv/mul are also supported. These changes should now be covered by tests with the updated DiffTests package. There are some linalg differences between 1.0 and the current 1.x, so some of the tests are disabled on 1.0, I guess it's the most straightforward way to handle incompatibilities.

src/dual.jl Outdated
Comment on lines 871 to 874
@eval LinearAlgebra.mul!(y::StridedMatrix{T}, m::$MT, x::StridedMatrix{T}) where T <: Dual =
_map_dual_components!((y, x) -> mul!(y, m, x), (y, x, _) -> mul!(y, m, x), y, x)

@eval Base.:*(m::$MT, x::StridedMatrix{<:Dual}) = mul!(similar(x, (size(m, 1), size(x, 2))), m, x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've lost track a bit, but since I was passing by:

  • Why does this case call _map_dual_components! and not just reinterpret y?
  • Why add a method to * here, won't it go to mul! anyway?
  • Should there be any methods at all for 3-arg mul!, or can they all be on 5-arg mul! only, as that's eventually called?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mcabbott Unfortunately, I also start to loose track, but AFAIR I was exploring these possibilities:

  • reinterpret() would only work for the dual vectors (translating it into normal matrix multiplication), but for dual matrices there's no representation as a normal matrix linalg operation
  • * and \ methods are required to overload Base Julia methods that would promote eltypes of all vectors and matrices to Dual, so we need to intercept that code path early
  • mul!(): AFAIR there are some 3-arg methods that don't call 5-arg methods, plus it was evolving from Julia 1.0 to 1.8, so it is hard to come up with the set of methods that are optimal for all the versions. This is a part of the PR that could be potentially more polished, but as the whole infrastructure of mul!()/ldiv!() methods in LinearAlgrebra is nontrivial, I was waiting for the ForwardDiff devs feedback and approval of the PR in principle before going forward.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually we'll get there!

I do think you can reinterpret dual-mat * matorvec. reinterpret(Float64, rand(3,4).+im) * rand(4, 5) makes a 6×5 Matrix{Float64} via existing methods, I think you're suggesting that there is no method which makes a 2×3×5 array, but this is just reshape. I think that @less mul!((rand(3).+im), (rand(3,3).+im), rand(3), true, false) works like this, without ever calling reinterpret but just passing pointers.

Agree that some of the other methods here need to catch * or / directly.

Re 1.0, the only reason not to drop it is not being sure what version of ForwardDiff to call that... I mean it shouldn't break on 1.0, but it's totally fine for 1.0 to miss these fast paths.

Copy link
Author

@alyst alyst Jun 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mcabbott BLAS has complex matrix-matrix multiplication (zgemm()), so there's need for reinterpretation. The K×L matrix z of duals with N partials could be represented as (N+1)×L×M array Z of reals. When multiplied by the constant K×L matrix A, the result y = A×z could be reinterpreted as (N+1)×K×M array. My point is that K and L are the middle dimension of Y and Z arrays, resp. So any combination of reshaping and transposing operations would not make these dimensions the first or the last one (so that the resulting matrix is compatible with matrix A multiplication). One would need permutedims, which involves array elements reshuffling.

@alyst
Copy link
Author

alyst commented Jun 22, 2023

I've rebased the PR and cleaned up commit history a bit. For tests to succeed, DiffTests 0.1.3 is required (JuliaDiff/DiffTests.jl#17).
There were concerns regarding how efficient is constant_matrix * dual_matrix case, and whether it could be reinterpreted as normal numeric matrix multiplication, but I think plain reinterpretation is not possible.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants