Compose-able self-contained units #37

Open

willtebbutt opened this issue Aug 17, 2017 · 0 comments

willtebbutt commented Aug 17, 2017
It would be useful to support stuff like this:
denizyuret/Knet.jl#144

We can pretty much already support this thanks to the design of ∇ in core.jl, specifically the method

∇(y::Node{T}, ȳ::T) where T

This method accepts an arbitrary reverse-mode sensitivity ȳ and doesn't restrict y to be a scalar, which means Tapes can be chained together. That is useful because it lets us define primitive-like functions / functors that use Nabla internally to differentiate themselves, and then use them inside larger functions / functors in the same way one would a primitive.
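To make the chaining concrete, here is a minimal sketch of the pattern (the Leaf / Tape construction mirrors the README; the exact way of indexing into the result of ∇ is an assumption on my part):

using Nabla

# Inner computation on its own tape.
W, x = randn(3, 3), randn(3)
W_, x_ = Leaf.(Tape(), (W, x))
y_ = tanh.(W_ * x_)              # y_ is a Node on the inner tape

# Outer computation on a separate tape, treating the inner result as a leaf.
y_outer = Leaf(Tape(), y_.val)
z_ = sum(y_outer)                # scalar output, so the outer reverse-pass needs no seed
ȳ = ∇(z_)[y_outer]               # sensitivity of z w.r.t. the inner output

# Seed the inner tape with the outer sensitivity to continue the reverse-pass.
∇inner = ∇(y_, ȳ)
∇W, ∇x = ∇inner[W_], ∇inner[x_]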

If, for example, one wished to create a self-contained differentiable layer for an MLP which handles its own parameters, one could create a functor and make it a primitive, along the following lines:

mutable struct Layer{T<:AbstractMatrix, V<:Node}
    W::T
    ∇W::T
    out::V
end
function (l::Layer)(x::AbstractVector)
    W_, x_ = Leaf.(Tape(), (l.W, x))
    y_ = tanh.(W_ * x_)
    l.out = y_ # Log the result of the forward-pass and the intermediate quantities required for the reverse-pass.
    return y_.val
end
@explicit_intercepts Layer Tuple{∇Real}
function ∇(l::Layer, ::Type{Arg{1}}, p, y, ȳ, x)
    ∇layer = ∇(l.out, ȳ) # Perform the reverse-pass, seeded with the sensitivity from the outer tape.
    l.∇W = ∇layer[1] # Log the gradient of `W` on this pass through.
    return ∇layer[2] # Return the reverse-mode sensitivity w.r.t. `x`.
end

This probably won't compile (and I'm not actually sure about some of the syntax regarding functors), but it illustrates the general idea: a functor could be used to implement a primitive-like block of code which keeps track of its own parameters and knows how to differentiate itself, without having to hand-code the gradients. (There are a few technical issues of consistency and scoping that I have glossed over, e.g. how to ensure that the x passed to ∇ is the same x originally passed to the Layer functor, but I'm going to assume that these could be resolved with some (relatively) minor changes to core.jl.)
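For concreteness, if the interception above worked as intended, using such layers inside a larger computation might look roughly like this (purely hypothetical: make_layer is a made-up helper that constructs a Layer of the given output × input size, and the way intercepted functor calls dispatch on Nodes is hand-waved):

# Two self-contained layers, each owning its weight matrix and a slot for its gradient.
l1, l2 = make_layer(4, 3), make_layer(2, 4)

# Compose them inside an outer tape, treating each layer call as a primitive.
x_ = Leaf(Tape(), randn(3))
y_ = l2(l1(x_))     # each call records its own inner tape in l.out
z_ = sum(y_)
∇z = ∇(z_)          # the reverse-pass hits ∇(l::Layer, ::Type{Arg{1}}, ...) once per layer

∇z[x_]              # reverse-mode sensitivity w.r.t. the input
l1.∇W, l2.∇W        # per-layer parameter gradients logged during the reverse-pass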

This seems like a generally useful thing to be able to do, so it would probably make sense to figure out the details and partially automate the process with a macro.
