Using @turbo with ForwardDiff.Dual for logsumexp #437

Open
magerton opened this issue Sep 30, 2022 · 6 comments

Comments

@magerton

Using @turbo loops gives incredible performance gains (10x) over the LogExpFunctions library for arrays of Float64s. However, @turbo doesn't seem to play well with ForwardDiff.Dual arrays and prints the warning below. Is there a way to leverage LoopVectorization to accelerate operations on Dual numbers?

`LoopVectorization.check_args` on your inputs failed; running fallback `@inbounds @fastmath` loop instead.
Use `warn_check_args=false`, e.g. `@turbo warn_check_args=false ...`, to disable this warning.

I'm uploading a Pluto notebook with some benchmarks, which I reproduce below.

I'm not sure whether this is related to #93. @chriselrod, I think this is related to your posts at https://discourse.julialang.org/t/speeding-up-my-logsumexp-function/42380/9?page=2 and https://discourse.julialang.org/t/fast-logsumexp-over-4th-dimension/64182/26

Thanks!

2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "Float64" => 6-element BenchmarkTools.BenchmarkGroup:
	  tags: ["Float64"]
	  "Vanilla Loop" => Trial(29.500 μs)
	  "Tullio" => Trial(5.200 μs)
	  "LogExpFunctions" => Trial(35.700 μs)
	  "Turbo" => Trial(3.000 μs)
	  "SIMD Loop" => Trial(25.500 μs)
	  "Vmap" => Trial(3.800 μs)
  "Dual" => 6-element BenchmarkTools.BenchmarkGroup:
	  tags: ["Dual"]
	  "Vanilla Loop" => Trial(45.300 μs)
	  "Tullio" => Trial(53.100 μs)
	  "LogExpFunctions" => Trial(62.800 μs)
	  "Turbo" => Trial(311.900 μs)
	  "SIMD Loop" => Trial(37.600 μs)
	  "Vmap" => Trial(44.300 μs)
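To put the table in relative terms, the ratios can be computed directly from the reported times. This is only arithmetic on the numbers posted above, not a new measurement; the dictionary names are illustrative.

```julia
# Times in μs, copied from the BenchmarkGroup output above
float64 = Dict("Vanilla Loop" => 29.5, "Tullio" => 5.2, "LogExpFunctions" => 35.7,
               "Turbo" => 3.0, "SIMD Loop" => 25.5, "Vmap" => 3.8)
dual    = Dict("Vanilla Loop" => 45.3, "Tullio" => 53.1, "LogExpFunctions" => 62.8,
               "Turbo" => 311.9, "SIMD Loop" => 37.6, "Vmap" => 44.3)

# Speedup of @turbo over LogExpFunctions for Float64 (> 1 means @turbo is faster)
speedup_float64 = float64["LogExpFunctions"] / float64["Turbo"]

# Slowdown of @turbo (which hits the check_args fallback) relative to LogExpFunctions for Duals
slowdown_dual = dual["Turbo"] / dual["LogExpFunctions"]

println("Float64: @turbo is ", round(speedup_float64; digits=1), "x faster than LogExpFunctions")
println("Dual:    @turbo is ", round(slowdown_dual; digits=1), "x slower than LogExpFunctions")
```

So the Float64 gain is roughly 12x, while on Duals the fallback path is roughly 5x slower than LogExpFunctions.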

The LoopVectorization functions are:

using LoopVectorization

"""
using `LoopVectorization.@turbo` loops

**NOTE** - not compatible with `ForwardDiff.Dual` numbers!
"""
function logsumexp_turbo!(Vbar, tmp_max, X)
	n,k = size(X)
	maximum!(tmp_max, X)
	fill!(Vbar, 0)
	@turbo for i in 1:n, j in 1:k
		Vbar[i] += exp(X[i,j] - tmp_max[i])
	end
	@turbo for i in 1:n
		Vbar[i] = log(Vbar[i]) + tmp_max[i]
	end
	return Vbar
end

"""
using `LoopVectorization` `vmap` convenience functions

**NOTE** - this DOES work with `ForwardDiff.Dual` numbers!
"""
function logsumexp_vmap!(Vbar, tmp_max, X, Xtmp)
	maximum!(tmp_max, X)
	n = size(X,2)
	for j in 1:n
		Xtmpj = view(Xtmp, :, j)
		Xj    = view(X, :, j)
		vmap!((xij, mi) -> exp(xij-mi), Xtmpj, Xj, tmp_max)
	end
	Vbartmp = vreduce(+, Xtmp; dims=2)
	vmap!((vi,mi) -> log(vi) + mi, Vbar, Vbartmp, tmp_max)
	return Vbar
end
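A Base-only reference is handy for checking that the @turbo and vmap versions return the right values. The sketch below (the name `logsumexp_rows_ref` is hypothetical, not part of any package) uses the same max-shift trick as both functions above: subtracting each row's maximum before `exp` keeps every exponent at or below zero, and the maximum is added back after the `log`.

```julia
# Hypothetical reference implementation of logsumexp(X; dims=2) using only Base.
# For each row i: log(sum_j exp(X[i,j])) == m_i + log(sum_j exp(X[i,j] - m_i)),
# where m_i = maximum(X[i, :]). Shifting by m_i prevents overflow in exp.
function logsumexp_rows_ref(X::AbstractMatrix{<:Real})
    m = vec(maximum(X; dims=2))            # row maxima
    s = vec(sum(exp.(X .- m); dims=2))     # shifted row sums, each >= 1
    return log.(s) .+ m
end

X = [1.0    2.0    3.0;
     1000.0 1000.0 1000.0]   # exp(1000.0) alone would overflow to Inf
v = logsumexp_rows_ref(X)
println(v)
```

Comparing this against the output of `logsumexp_turbo!` with `≈` is a quick correctness check for the vectorized versions.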
@magerton
Author

See the notebook: logsumexp-speedtests.pdf

@magerton
Author

magerton commented Sep 30, 2022

HTML rendering of the notebook (strip off the .txt extension): logsumexp-speedtests.jl.html.txt

Pluto notebook (strip off the .txt extension): logsumexp-speedtests.jl.txt

@magerton
Author

magerton commented Sep 30, 2022

I was able to speed up the Dual case a bit by pirating vexp and log_fast (defining ForwardDiff.Dual methods for them), though the relative speedup (2x) is still smaller than what @turbo achieves for Float64 arrays.

using ForwardDiff
const FD = ForwardDiff
import VectorizationBase: vexp
import SLEEFPirates: log_fast

@inline function vexp(d::FD.Dual{T}) where {T}
    val = vexp(FD.value(d))
    partials =  FD.partials(d)
    return FD.Dual{T}(val, val * partials)
end

@inline function log_fast(d::FD.Dual{T}) where {T}
    val = FD.value(d)
    partials =  FD.partials(d)
    return FD.Dual{T}(log_fast(val), inv(val) * partials)
end

"using base SIMD loops with LoopVectorization tricks"
function logsumexp_tricks!(Vbar, tmp_max, X)
	m,n = size(X)
	maximum!(tmp_max, X)
	fill!(Vbar, 0)
	@inbounds for j in 1:n
		@simd for i in 1:m
			Vbar[i] += vexp(X[i,j] - tmp_max[i])
		end
	end
    	
	@inbounds @simd for i in 1:m
		Vbar[i] = log_fast(Vbar[i]) + tmp_max[i]
	end
	return Vbar
end
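The two pirated methods encode the standard forward-mode rules for `exp` and `log`. Writing a Dual as a value a plus partials b (one entry per tracked direction, with ε² = 0):

```latex
\exp(a + b\,\varepsilon) = e^{a} + e^{a}\, b\,\varepsilon,
\qquad
\log(a + b\,\varepsilon) = \log a + \frac{b}{a}\,\varepsilon \quad (a > 0).
```

This is why the `vexp` method returns `val * partials` (where `val` is already `vexp` of the value) and the `log_fast` method returns `inv(val) * partials` (where `val` is the value itself).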

@magerton
Author

magerton commented Sep 30, 2022

This also worked, though it wasn't quite as fast:

"using `@turbo` loops with the pirated `vexp` and `log_fast` methods"
function logsumexp_turbo2!(Vbar, tmp_max, X)
	m,n = size(X)
	maximum!(tmp_max, X)
	fill!(Vbar, 0)
    @turbo safe=false warn_check_args=false for i in 1:m, j in 1:n
		Vbar[i] += vexp(X[i,j] - tmp_max[i])
	end
    	
	@turbo safe=false warn_check_args=false for i in 1:m
		Vbar[i] = log_fast(Vbar[i]) + tmp_max[i]
	end
	return Vbar
end

@chriselrod
Member

chriselrod commented Sep 30, 2022

I may respond with more later, but you can get some ideas for more tricks here:
https://github.com/PumasAI/SimpleChains.jl/blob/main/src/forwarddiff_matmul.jl

Long term, the rewrite should "just work" for duals and generic Julia code. However, it is currently a long way off; I'm still working on the rewrite's dependence analysis (which LoopVectorization.jl of course doesn't do at all).

@magerton
Author

magerton commented Oct 1, 2022

Thank you for the quick response, @chriselrod; I really appreciate it. I had a bit of a hard time understanding the code you referenced, but it looked to me like the strategy was to reinterpret the arrays as Float64 arrays with reinterpret(reshape, T, A) and do the derivative computations separately. Is that the strategy you're suggesting? I tried it and got a big speedup over LogExpFunctions for logsumexp(X::AbstractVector{<:ForwardDiff.Dual}).
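The reinterpret strategy depends on memory layout: a Dual stores its value followed by its K partials, so `reinterpret(reshape, V, X)` on a vector of n Duals gives a (K+1)×n matrix whose first row holds the values and whose rows 2:K+1 hold the partials. The sketch below shows this with a hypothetical stand-in struct `ToyDual` so it runs without ForwardDiff; it assumes ForwardDiff's `Dual` has the same value-then-partials field order and no padding, which is what the code in this thread relies on (requires Julia 1.6 or later for `reinterpret(reshape, ...)`).

```julia
# Hypothetical stand-in for ForwardDiff.Dual{T,Float64,K}: an isbits struct
# storing a value followed by K partials, with no padding between fields.
struct ToyDual{K}
    value::Float64
    partials::NTuple{K,Float64}
end

X = [ToyDual{2}(1.0, (0.5, 2.0)),
     ToyDual{2}(3.0, (1.0, -1.0)),
     ToyDual{2}(2.0, (0.0, 4.0)),
     ToyDual{2}(0.0, (1.0, 1.0))]

# Reinterpret the same memory as Float64: size (K+1) x n = 3 x 4.
Xre = reinterpret(reshape, Float64, X)

vals  = Xre[1, :]      # row 1 holds the values (getindex copies into a new Vector)
parts = Xre[2:end, :]  # rows 2:K+1 hold the partials, one column per element

# Xre is a view, not a copy: writing to it changes the underlying ToyDuals.
# The matrix case below relies on this when it writes results through Vre.
Xre[1, 1] = 10.0
println(size(Xre), " ", X[1].value)
```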

See FastLogSumExp.jl. Benchmarks for these and a few more are at https://github.com/magerton/FastLogSumExp.jl; the benchmarking is done in runtests.jl. Using @turbo on reinterpreted arrays gives a ~5-6x speedup on my current machine and gave ~10x on my other one.

Vector case

using LoopVectorization, StaticArrays  # StaticArrays provides MVector

"fastest logsumexp over a Dual vector; requires a preallocated tmp vector"
function vec_logsumexp_dual_reinterp!(tmp::AbstractVector{V}, X::AbstractVector{<:FD.Dual{T,V,K}}) where {T,V,K}
    Xre   = reinterpret(reshape, V, X)

    uv = typemin(V)
    @turbo for i in eachindex(X)
        uv = max(uv, Xre[1,i])
    end

    s = zero(V)

    @turbo for j in eachindex(X,tmp)
        ex = exp(Xre[1,j] - uv)
        tmp[j] = ex
        s += ex
    end

    v = log(s) + uv # logsumexp value

    invs = inv(s) # for doing softmax for derivatives

    # would be nice to use a more elegant construction for
    # pvec instead of multiple conversions below
    # that said, it seems like we still have zero allocations
    pvec = zeros(MVector{K,V})
    @turbo for j in eachindex(X,tmp)
        tmp[j] *= invs
        for k in 1:K
            pvec[k] += tmp[j]*Xre[k+1,j]
        end
    end

    ptup = NTuple{K,V}(pvec)
    ptl = FD.Partials{K,V}(ptup)

    return FD.Dual{T,V,K}(v, ptl)

end
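The derivative part of `vec_logsumexp_dual_reinterp!` uses the fact that the gradient of logsumexp is the softmax: with weights w_j = exp(x_j - u) / s, the k-th partial of the result is the sum over j of w_j times the k-th partial of x_j. Below is a Base-only sketch of the same computation on the reinterpreted (K+1)×n layout, with plain loops instead of @turbo and a `Vector` instead of an `MVector`. The function name `logsumexp_dual_layout` is hypothetical; the result is checked against a central finite difference.

```julia
# Hypothetical Base-only version of the vector case, operating on a raw
# (K+1) x n Float64 matrix laid out like reinterpret(reshape, V, X):
# row 1 = values, rows 2:K+1 = partials.
function logsumexp_dual_layout(Xre::AbstractMatrix{Float64})
    K = size(Xre, 1) - 1
    n = size(Xre, 2)
    u = maximum(@view Xre[1, :])       # shift for numerical stability
    w = Vector{Float64}(undef, n)
    s = 0.0
    for j in 1:n
        w[j] = exp(Xre[1, j] - u)
        s += w[j]
    end
    v = log(s) + u                     # logsumexp value
    w ./= s                            # softmax weights, summing to 1
    p = zeros(K)
    for j in 1:n, k in 1:K
        p[k] += w[j] * Xre[k + 1, j]   # chain rule: sum_j w_j * (partial of x_j)
    end
    return v, p
end

# One direction (K = 1): x_j(t) = a_j + t * b_j, so the partial of x_j is b_j.
a = [0.3, -1.2, 2.5, 0.0]
b = [1.0, 2.0, -0.5, 3.0]
Xre = permutedims([a b])               # 2 x 4 matrix: row 1 values, row 2 partials
v, p = logsumexp_dual_layout(Xre)

# Central finite difference of t -> logsumexp(a .+ t .* b) at t = 0
lse(x) = log(sum(exp, x))
h = 1e-6
fdiff = (lse(a .+ h .* b) - lse(a .- h .* b)) / (2h)
println((v, p[1], fdiff))
```

The finite-difference check only confirms the formula; the speed in the thread comes from running these same loops under @turbo on the reinterpreted Float64 data.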

Matrix case for logsumexp(X; dims=2)

function mat_logsumexp_dual_reinterp!(
    Vbar::AbstractVector{D}, tmp_max::AbstractVector{V}, 
    tmpX::Matrix{V}, X::AbstractMatrix{D}
    ) where {T,V,K,D<:FD.Dual{T,V,K}}
    
    m,n = size(X)

    (m,n) == size(tmpX) || throw(DimensionMismatch())
    (m,) == size(Vbar) == size(tmp_max) || throw(DimensionMismatch())

    Vre   = reinterpret(reshape, V, Vbar)
    Xre   = reinterpret(reshape, V, X)

    tmp_inv = tmp_max # reuse the buffer (alias): tmp_max[i] is overwritten with the inverse row sum below

    fill!(Vbar, 0)
    fill!(tmp_max, typemin(V))

    @turbo for i in 1:m, j in 1:n
        tmp_max[i] = max(tmp_max[i], Xre[1,i,j])
    end

    @turbo for i in 1:m, j in 1:n
        ex = exp(Xre[1,i,j] - tmp_max[i])
        tmpX[i,j] = ex
        Vre[1,i] += ex
    end

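    @turbo for i in 1:m
        v = Vre[1,i]
        mi = tmp_max[i]     # named `mi` so the row count `m` (a loop bound below) is not overwritten
        tmp_inv[i] = inv(v) # tmp_inv aliases tmp_max; safe because tmp_max[i] was read above
        Vre[1,i] = log(v) + mi
    end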

    @turbo for i in 1:m, j in 1:n, k in 1:K
        Vre[k+1,i] += tmpX[i,j]*Xre[k+1,i,j]*tmp_inv[i]
    end

    return Vbar

end
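Row by row, the four @turbo loops in `mat_logsumexp_dual_reinterp!` compute the following, where x_{ij} is the value of `X[i,j]`, \dot x_{ij}^{(k)} is its k-th partial, and u_i is the row maximum stored in `tmp_max`:

```latex
u_i = \max_j x_{ij}, \qquad
s_i = \sum_{j=1}^{n} e^{x_{ij} - u_i}, \qquad
v_i = \log s_i + u_i,
\\[4pt]
\dot v_i^{(k)} = \sum_{j=1}^{n} \frac{e^{x_{ij} - u_i}}{s_i}\, \dot x_{ij}^{(k)}, \qquad k = 1,\dots,K.
```

The first loop produces u_i; the second stores e^{x_ij - u_i} in `tmpX` and accumulates s_i in the value slot of `Vbar`; the third turns that slot into v_i and keeps 1/s_i in `tmp_inv`; the last accumulates the K partials of each v_i.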
