This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Port sensitivity definitions to ChainRules #177

Open
ararslan opened this issue Jun 18, 2019 · 0 comments

ararslan (Collaborator) commented Jun 18, 2019

For Nabla to use ChainRules for sensitivities with full feature parity with the current implementation, we'll need to port the sensitivity definitions here (i.e. the ∇ methods) and turn them into rrule methods in ChainRules.

How to port

Porting them over is actually pretty straightforward. Nabla's ∇ methods take an Arg{i} argument that indicates which of the function's arguments the current method differentiates with respect to. For example,

f(a::Int, b::Int) = a + 2b + 1

# Derivative for the first argument, `a`
∇(::typeof(f), ::Type{Arg{1}}, p, y, ȳ, a::Int, b::Int) = ȳ

# Derivative for the second, `b`
∇(::typeof(f), ::Type{Arg{2}}, p, y, ȳ, a::Int, b::Int) = 2ȳ

ChainRules rrule methods include both derivatives in a single method. So the above translates to

function rrule(::typeof(f), a::Int, b::Int)
    y = f(a, b)
    ∂a = Rule(ȳ -> ȳ)
    ∂b = Rule(ȳ -> 2ȳ)
    return y, (∂a, ∂b)
end
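
To see where the two closures come from, the partial derivatives can be written out. This is only a worked restatement of the example; the symbols \bar{a} and \bar{b} are notation assumed here for the sensitivities of a and b, by analogy with ȳ.

```latex
\begin{aligned}
y &= f(a, b) = a + 2b + 1, \\
\bar{a} &= \bar{y}\,\frac{\partial y}{\partial a} = \bar{y} \cdot 1 = \bar{y}, \\
\bar{b} &= \bar{y}\,\frac{\partial y}{\partial b} = \bar{y} \cdot 2 = 2\bar{y}.
\end{aligned}
```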

There are cases where a ∇ method is purposefully not defined for a given Arg{i}; that denotes that there is no derivative with respect to that argument. In ChainRules, we express that by returning a DNERule() in place of the Rule. So if in the above example f were only differentiable with respect to b, the rrule would instead look like

function rrule(::typeof(f), a::Int, b::Int)
    y = f(a, b)
    ∂b = Rule(ȳ -> 2ȳ)
    return y, (DNERule(), ∂b)
end

Also note that the derivatives for the various arguments can share intermediate computation. That can go into the body of the rrule method itself, with the defined variables captured in the closures in the Rules.
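
As a rough illustration of sharing intermediate computation, here is a minimal sketch. It does not load ChainRules: SimpleRule is a hypothetical stand-in for Rule, g is a made-up example function, and rrule is a plain local function written in the same style, so the snippet runs on its own.

```julia
# Hypothetical stand-in for ChainRules' `Rule`, defined locally so the sketch
# runs without any packages: it simply wraps a function of the incoming ȳ.
struct SimpleRule{F}
    f::F
end
(r::SimpleRule)(ȳ) = r.f(ȳ)

# Hypothetical example function (not taken from Nabla or ChainRules).
g(a::Float64, b::Float64) = sin(a * b)

# A local function named `rrule`, written in the style described above.
function rrule(::typeof(g), a::Float64, b::Float64)
    ab = a * b       # shared by the primal value and both derivatives
    y = sin(ab)
    c = cos(ab)      # computed once, captured by both closures below
    ∂a = SimpleRule(ȳ -> ȳ * c * b)
    ∂b = SimpleRule(ȳ -> ȳ * c * a)
    return y, (∂a, ∂b)
end

y, (∂a, ∂b) = rrule(g, 2.0, 3.0)
```

Here cos(ab) is evaluated once in the body of the method rather than once per argument, which is the saving that separate per-argument methods cannot express as directly.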

There are some cases where Nabla defines custom methods for updating the tape with a given sensitivity. Those are expressed as methods of ∇ with the tape value as the first argument. ChainRules does this differently: if you have a special way in which you'd like to accumulate a sensitivity to a given value, you provide a second argument to Rule that's another function that takes arguments (x̄, ȳ). This is used by the ChainRules.accumulate!(value, rule, args...) method. Just like Nabla, if no such special method exists for updating, a generic fallback is used.
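
The accumulation mechanism might look roughly like the sketch below. It is not the ChainRules implementation: UpdatingRule and accumulate_sketch! are hypothetical stand-ins defined locally for a two-argument Rule and for ChainRules.accumulate!, and the fallback shown (plain in-place addition) is a simplifying assumption.

```julia
# Hypothetical stand-in for a two-argument `Rule`: `f` computes the sensitivity,
# and the optional `u` updates an existing accumulator in place.
struct UpdatingRule{F,U}
    f::F
    u::U
end
UpdatingRule(f) = UpdatingRule(f, nothing)
(r::UpdatingRule)(args...) = r.f(args...)

# Generic fallback (assumed): compute the sensitivity, then add it to `value`.
function accumulate_sketch!(value, r::UpdatingRule{F,Nothing}, args...) where {F}
    value .+= r(args...)
    return value
end

# Custom path: hand the accumulator and the arguments to the update function.
function accumulate_sketch!(value, r::UpdatingRule, args...)
    r.u(value, args...)
    return value
end

# Sensitivity of x -> 2x, once without and once with a custom in-place update.
plain = UpdatingRule(ȳ -> 2 .* ȳ)
fused = UpdatingRule(ȳ -> 2 .* ȳ, (x̄, ȳ) -> (x̄ .+= 2 .* ȳ))

acc_plain = accumulate_sketch!(zeros(3), plain, ones(3))
acc_fused = accumulate_sketch!(zeros(3), fused, ones(3))
```

Both paths give the same accumulated value; the custom update just avoids materializing the intermediate 2 .* ȳ array before adding it.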

Progress

Below is a list of all of the basic methods. More items are finished than are currently checked as of this writing; as you find that ChainRules does indeed include a corresponding rrule method and the sensitivity definition it uses looks correct, please check these off.

Note that this list does not include methods which update the tape directly!

  • *(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
  • *(::Adjoint, ::Adjoint)
  • *(::Adjoint, ::StridedMatrix)
  • *(::Number, ::Number)
  • *(::StridedMatrix, ::Adjoint)
  • *(::StridedMatrix, ::StridedMatrix)
  • *(::StridedMatrix, ::Transpose)
  • *(::Transpose, ::StridedMatrix)
  • *(::Transpose, ::Transpose)
  • *(::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number}, ::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number})
  • +(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
  • +(::AbstractArray{#s12<:Number,N} where N, ::UniformScaling{T<:Number})
  • +(::Number)
  • +(::Number, ::Number)
  • +(::UniformScaling{T<:Number}, ::AbstractArray{#s12<:Number,N} where N)
  • -(::AbstractArray{#s12<:Number,N} where N)
  • -(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
  • -(::Number)
  • -(::Number, ::Number)
  • /(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
  • /(::Number, ::Number)
  • /(::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number}, ::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number})
  • Cholesky(::Union{LowerTriangular, UpperTriangular}, ::Union{Char, Symbol}, ::Integer)
  • Diagonal(::AbstractArray{#s12<:Number,1})
  • Diagonal(::AbstractArray{#s12<:Number,2})
  • LinearAlgebra.BLAS.asum(::Any)
  • LinearAlgebra.BLAS.asum(::Integer, ::Any, ::Integer)
  • LinearAlgebra.BLAS.dot(::Int64, ::StridedVector, ::Int64, ::StridedVector, ::Int64)
  • LinearAlgebra.BLAS.gemm(::Char, ::Char, ::StridedMatrix, ::StridedMatrix) (Implement sensitivities for BLAS.gemm JuliaDiff/ChainRules.jl#25)
  • LinearAlgebra.BLAS.gemm(::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedMatrix) (Implement sensitivities for BLAS.gemm JuliaDiff/ChainRules.jl#25)
  • LinearAlgebra.BLAS.gemv(::Char, ::StridedMatrix, ::StridedVector)
  • LinearAlgebra.BLAS.gemv(::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
  • LinearAlgebra.BLAS.nrm2(::Any)
  • LinearAlgebra.BLAS.nrm2(::Integer, ::Any, ::Integer)
  • LinearAlgebra.BLAS.symm(::Char, ::Char, ::StridedMatrix, ::StridedVector)
  • LinearAlgebra.BLAS.symm(::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
  • LinearAlgebra.BLAS.symv(::Char, ::StridedMatrix, ::StridedVector)
  • LinearAlgebra.BLAS.symv(::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
  • LinearAlgebra.BLAS.trmm(::Char, ::Char, ::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
  • LinearAlgebra.BLAS.trmm(::Char, ::Char, ::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
  • LinearAlgebra.BLAS.trsm(::Char, ::Char, ::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedMatrix)
  • LinearAlgebra.BLAS.trsm(::Char, ::Char, ::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
  • LinearAlgebra.BLAS.trsv(::Char, ::Char, ::Char, ::StridedMatrix, ::StridedVector)
  • LinearAlgebra.cholesky(::AbstractArray{T<:Number,2}) (Add an rrule for the Cholesky decomposition JuliaDiff/ChainRules.jl#44)
  • LinearAlgebra.det(::AbstractArray{#s12<:Number,N} where N)
  • LinearAlgebra.det(::Diagonal{#s77<:Number,V} where V<:AbstractArray{#s77<:Number,1})
  • LinearAlgebra.det(::LowerTriangular{#s77<:Number,S} where S<:AbstractArray{#s77<:Number,2})
  • LinearAlgebra.det(::UpperTriangular{#s77<:Number,S} where S<:AbstractArray{#s77<:Number,2})
  • LinearAlgebra.diag(::AbstractArray{#s12<:Number,2})
  • LinearAlgebra.diag(::AbstractArray{#s12<:Number,2}, ::Integer)
  • LinearAlgebra.dot(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
  • LinearAlgebra.logdet(::AbstractArray{#s12<:Number,N} where N)
  • LinearAlgebra.logdet(::Diagonal{#s77<:Number,V} where V<:AbstractArray{#s77<:Number,1})
  • LinearAlgebra.logdet(::LowerTriangular{#s77<:Number,S} where S<:AbstractArray{#s77<:Number,2})
  • LinearAlgebra.logdet(::UpperTriangular{#s77<:Number,S} where S<:AbstractArray{#s77<:Number,2})
  • LinearAlgebra.norm(::AbstractArray{#s12<:Number,N} where N) (Add rrules for binary linear algebra operations JuliaDiff/ChainRules.jl#29)
  • LinearAlgebra.norm(::AbstractArray{#s12<:Number,N} where N, ::Number) (Add rrules for binary linear algebra operations JuliaDiff/ChainRules.jl#29)
  • LinearAlgebra.norm(::Number) (Add rrules for binary linear algebra operations JuliaDiff/ChainRules.jl#29)
  • LinearAlgebra.norm(::Number, ::Number) (Add rrules for binary linear algebra operations JuliaDiff/ChainRules.jl#29)
  • LinearAlgebra.svd(::AbstractArray{T,2}) (Add SVD factorization rrule JuliaDiff/ChainRules.jl#31)
  • LinearAlgebra.tr(::AbstractArray{#s12<:Number,N} where N)
  • LowerTriangular(::AbstractArray{#s12<:Number,2})
  • SpecialFunctions.airyai(::Number)
  • SpecialFunctions.airyaiprime(::Number)
  • SpecialFunctions.airybi(::Number)
  • SpecialFunctions.airybiprime(::Number)
  • SpecialFunctions.besseli(::Number, ::Number)
  • SpecialFunctions.besselj(::Number, ::Number)
  • SpecialFunctions.besselj0(::Number)
  • SpecialFunctions.besselj1(::Number)
  • SpecialFunctions.besselk(::Number, ::Number)
  • SpecialFunctions.bessely(::Number, ::Number)
  • SpecialFunctions.bessely0(::Number)
  • SpecialFunctions.bessely1(::Number)
  • SpecialFunctions.beta(::Number, ::Number)
  • SpecialFunctions.dawson(::Number)
  • SpecialFunctions.digamma(::Number)
  • SpecialFunctions.erf(::Number)
  • SpecialFunctions.erfc(::Number)
  • SpecialFunctions.erfcinv(::Number)
  • SpecialFunctions.erfcx(::Number)
  • SpecialFunctions.erfi(::Number)
  • SpecialFunctions.erfinv(::Number)
  • SpecialFunctions.gamma(::Number)
  • SpecialFunctions.invdigamma(::Number)
  • SpecialFunctions.lbeta(::Number, ::Number)
  • SpecialFunctions.lgamma(::Number)
  • SpecialFunctions.polygamma(::Number, ::Number)
  • SpecialFunctions.trigamma(::Number)
  • Statistics.mean(::AbstractArray{#s74<:Number,N} where N) (Add a few more reduction rrules JuliaDiff/ChainRules.jl#59)
  • Statistics.mean(::Function, ::AbstractArray{#s77<:Number,N} where N) (Add a few more reduction rrules JuliaDiff/ChainRules.jl#59)
  • UpperTriangular(::AbstractArray{#s12<:Number,2})
  • \(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
  • \(::Number, ::Number)
  • \(::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number}, ::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number})
  • ^(::Number, ::Number)
  • abs(::Number)
  • abs2(::Number)
  • acos(::Number)
  • acosd(::Number)
  • acosh(::Number)
  • acot(::Number)
  • acotd(::Number)
  • acoth(::Number)
  • acsc(::Number)
  • acscd(::Number)
  • acsch(::Number)
  • adjoint(::AbstractArray{#s12<:Number,N} where N)
  • adjoint(::Number)
  • asec(::Number)
  • asecd(::Number)
  • asech(::Number)
  • asin(::Number)
  • asind(::Number)
  • asinh(::Number)
  • atand(::Number)
  • atanh(::Number)
  • broadcast(::Any, ::Vararg{Any,N})
  • cbrt(::Number)
  • copy(::Any)
  • cos(::Number)
  • cosd(::Number)
  • cosh(::Number)
  • cospi(::Number)
  • cot(::Number)
  • cotd(::Number)
  • coth(::Number)
  • csc(::Number)
  • cscd(::Number)
  • csch(::Number)
  • deg2rad(::Number)
  • exp(::AbstractArray{T,2})
  • exp(::Number)
  • exp10(::Number)
  • exp2(::Number)
  • expm1(::Number)
  • fill(::Any, ::Vararg{Any,N}) (Port more array-related derivatives from Nabla JuliaDiff/ChainRules.jl#45)
  • float(::Any)
  • getindex(::Any, ::Vararg{Any,N})
  • getproperty(::Cholesky{T,S} where S<:(AbstractArray{T,2} where T), ::Symbol) (Add an rrule for the Cholesky decomposition JuliaDiff/ChainRules.jl#44)
  • getproperty(::SVD{T,Tr,M} where M<:(AbstractArray{T,N} where N) where Tr, ::Symbol) (Add SVD factorization rrule JuliaDiff/ChainRules.jl#31)
  • hcat(::Vararg{AbstractArray,N}) (Port more array-related derivatives from Nabla JuliaDiff/ChainRules.jl#45)
  • hypot(::Number, ::Number)
  • identity(::Any) (Add a few more reduction rrules JuliaDiff/ChainRules.jl#59)
  • inv(::AbstractArray{#s12<:Number,N} where N)
  • inv(::Number)
  • kron(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
  • log(::Number)
  • log10(::Number)
  • log2(::Number)
  • map(::Function, ::Vararg{AbstractArray{#s12,N} where N where #s12<:Number,N}) (Add rrule for map and expand the testing framework JuliaDiff/ChainRules.jl#56)
  • mapfoldl(::Any, ::Union{typeof(+), typeof(add_sum)}, ::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number}) (Add a few more reduction rrules JuliaDiff/ChainRules.jl#59)
  • mapfoldr(::Any, ::Union{typeof(+), typeof(add_sum)}, ::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number}) (Add a few more reduction rrules JuliaDiff/ChainRules.jl#59)
  • mapreduce(::Any, ::Union{typeof(+), typeof(add_sum)}, ::AbstractArray{#s38<:Number,N} where N) (Add a few more reduction rrules JuliaDiff/ChainRules.jl#59)
  • max(::Number, ::Number)
  • min(::Number, ::Number)
  • rad2deg(::Number)
  • reshape(::AbstractArray{#s12<:Number,N} where N, ::Vararg{Any,N}) (Port more array-related derivatives from Nabla JuliaDiff/ChainRules.jl#45)
  • sec(::Number)
  • secd(::Number)
  • sech(::Number)
  • sin(::Number)
  • sind(::Number)
  • sinh(::Number)
  • sinpi(::Number)
  • sqrt(::Number)
  • sum(::AbstractArray{#s68<:Number,N} where N) (Add a few more reduction rrules JuliaDiff/ChainRules.jl#59)
  • sum(::Function, ::AbstractArray{#s64<:Number,N} where N) (Add a few more reduction rrules JuliaDiff/ChainRules.jl#59)
  • sum(::typeof(abs2), ::AbstractArray{#s73<:Number,N} where N) (Add a few more reduction rrules JuliaDiff/ChainRules.jl#59)
  • tan(::Number)
  • tand(::Number)
  • tanh(::Number)
  • transpose(::AbstractArray{#s12<:Number,N} where N)
  • transpose(::Number)
  • vcat(::Vararg{AbstractArray,N}) (Port more array-related derivatives from Nabla JuliaDiff/ChainRules.jl#45)