This repository has been archived by the owner on Apr 18, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
core.jl
293 lines (256 loc) · 9.77 KB
/
core.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
using DualNumbers
import Base: push!, length, show, getindex, setindex!, eachindex, isassigned,
isapprox, zero, one, lastindex
export Leaf, Tape, Node, Branch, ∇
"""
    Node{T}

Basic unit on the computational graph. Concrete subtypes in this file are
`Leaf` (an input) and `Branch` (the result of applying a function to other
`Node`s); `T` is the type of the boxed value.
"""
abstract type Node{T} end
"""
    Tape()
    Tape(N::Int)

A topologically ordered collection of `Node`s, backed by a `Vector{Any}`.
`Tape()` builds an empty tape; `Tape(N)` builds one with `N` undefined slots.
"""
struct Tape
    tape::Vector{Any}
    Tape() = new(Any[])
    function Tape(N::Int)
        return new(Vector{Any}(undef, N))
    end
end
"""
    show(io::IO, t::Tape)

Print a one-line summary ("Tape with N elements:") followed by one line per
tape slot; slots that have not been assigned yet are rendered as `#undef`.
"""
function show(io::IO, t::Tape)
    n = length(t)
    plural = n == 1 ? "" : "s"
    trailer = n > 0 ? ":" : ""
    print(io, "Tape with ", n, " element", plural, trailer)
    for idx in eachindex(t)
        print(io, "\n [", idx, "]: ")
        isassigned(tape(t), idx) ? show(io, t[idx]) : print(io, "#undef")
    end
end
# --- Indexing / iteration interface for `Tape`, all forwarding to the raw vector ---

# Fetch the stored value at position `n`, or at a `Node`'s recorded position.
@inline getindex(t::Tape, n::Int) = tape(t)[n]
@inline getindex(t::Tape, node::Node) = t[pos(node)]
@inline lastindex(t::Tape) = length(t)
@inline function setindex!(t::Tape, x, n::Int)
    tape(t)[n] = x
    return t
end
@inline eachindex(t::Tape) = eachindex(tape(t))
@inline length(t::Tape) = length(tape(t))
@inline function push!(t::Tape, node::Node)
    push!(tape(t), node)
    return t
end
@inline isassigned(t::Tape, n::Int) = isassigned(tape(t), n)
@inline isassigned(t::Tape, node::Node) = isassigned(t, pos(node))
# Make `Tape`s broadcast as scalars without a warning on 0.7
Base.Broadcast.broadcastable(tape::Tape) = Ref(tape)
"""
    Leaf{T} <: Node{T}

An element at the 'bottom' of the computational graph.

Fields:
- `val`  - the value of the node.
- `tape` - the `Tape` to which this `Leaf` is assigned.
- `pos`  - the location of this `Leaf` in the tape to which it is assigned.
"""
struct Leaf{T} <: Node{T}
    val::T
    tape::Tape
    pos::Int
end
"""
    Leaf(tape::Tape, val)

Box `val` as a `Leaf`, append it to `tape`, and return the new `Leaf`. Its
recorded position is the slot it occupies (one past the current tape end).
"""
function Leaf(tape::Tape, val)
    new_leaf = Leaf(val, tape, length(tape) + 1)
    push!(tape, new_leaf)
    return new_leaf
end
# Compact display for `Leaf`s: scalars show their value, arrays only their size.
function show(io::IO, leaf::Leaf{T}) where T
    print(io, "Leaf{", T, "} ", unbox(leaf))
end
function show(io::IO, leaf::Leaf{T}) where T<:AbstractArray
    print(io, "Leaf{", T, "} ", size(unbox(leaf)))
end
"""
    Branch{T} <: Node{T}

A Branch is a Node with parents (args).

Fields:
- `val`    - the value of this node produced in the forward pass.
- `f`      - the function used to generate this Node.
- `args`   - Values indicating which elements in the tape will require updating by this node.
- `kwargs` - NamedTuple of keyword arguments that were passed to `f` in the forward pass.
- `tape`   - The Tape to which this Branch is assigned.
- `pos`    - the location of this Branch in the tape to which it is assigned.
"""
struct Branch{T} <: Node{T}
    val::T
    f
    args::Tuple
    kwargs::NamedTuple
    tape::Tape
    pos::Int
end
"""
    Branch(f, args::Tuple, tape::Tape; kwargs...)

Run `f` forwards on the unboxed `args` (with `kwargs`), record the call as a
`Branch` appended to `tape`, and return the new `Branch`.
"""
function Branch(f, args::Tuple, tape::Tape; kwargs...)
    unboxed = unbox.(args)
    # `values` on the keyword-argument pairs iterator returns the underlying
    # NamedTuple; this is the documented API, unlike reaching into the private
    # `:data` field via `getfield`.
    branch = Branch(
        f(unboxed...; kwargs...), f, args, values(kwargs), tape, length(tape) + 1
    )
    push!(tape, branch)
    return branch
end
# Compact display for `Branch`es: scalars show their value, arrays only their
# size; both forms include the recorded function `f`.
function show(io::IO, branch::Branch{T}) where T
    print(io, "Branch{", T, "} ", unbox(branch), " f=", getfield(branch, :f))
end
function show(io::IO, branch::Branch{T}) where T<:AbstractArray
    print(io, "Branch{", T, "} ", size(unbox(branch)), " f=", getfield(branch, :f))
end
"""
    tape(x::Node)
    tape(x::Tape)

Retrieve the `Tape` in a `Node`, or the underlying vector in a `Tape`.
"""
tape(x::Node) = getfield(x, :tape)
tape(x::Tape) = getfield(x, :tape)
"""
    pos(x::Node)
    pos(x)

Location of a `Node` on its tape; `-1` for anything that is not a `Node`.
"""
pos(node::Node) = getfield(node, :pos)
pos(::Any) = -1
"""
    unbox(x::Node)
    unbox(x)

Get `.val` if `x` is a `Node`; otherwise return `x` unchanged (`identity`).
"""
unbox(node::Node) = getfield(node, :val)
unbox(other) = other
# Compare `Node`s by their boxed values; `zero`/`one` likewise delegate to the
# boxed value so they produce plain (unboxed) results.
isapprox(a::Node, b::Node) = unbox(a) ≈ unbox(b)
isapprox(a::Node, b) = unbox(a) ≈ b
isapprox(a, b::Node) = b ≈ a
zero(x::Node) = zero(unbox(x))
one(x::Node) = one(unbox(x))
# Leafs do nothing, Branches compute their own sensitivities and update others.
@inline propagate(y::Leaf, rvs_tape::Tape) = nothing
# Reverse-pass step for a single `Branch`: read the accumulated sensitivity ȳ
# of `y` from the reverse tape, then for each parent that is itself a `Node`
# (i.e. has a tape position > 0) compute its sensitivity via `∇` and store /
# accumulate it in the parent's reverse-tape slot.
function propagate(y::Branch, rvs_tape::Tape)
    tape = Nabla.tape(rvs_tape)
    ȳ, f = tape[pos(y)], getfield(y, :f)
    args = getfield(y, :args)
    kwargs = getfield(y, :kwargs)
    # Unboxed parent values and their tape positions (-1 for non-Node args).
    xs, xids = map(unbox, args), map(pos, args)
    # Shared precomputation, reused by every per-argument `∇` call below.
    p = preprocess(f, unbox(y), ȳ, xs...)
    for j in eachindex(xs)
        x, xid = xs[j], xids[j]
        if xid > 0
            # First write initialises the slot; subsequent writes accumulate
            # into the existing sensitivity (the `x̄`-taking `∇` method).
            tape[xid] = isassigned(tape, xid) ?
                ∇(tape[xid], f, Arg{j}, p, unbox(y), ȳ, xs...; kwargs...) :
                ∇(f, Arg{j}, p, unbox(y), ȳ, xs...; kwargs...)
        end
    end
    return nothing
end
"""
    propagate(fwd_tape::Tape, rvs_tape::Tape)

Run the reverse pass: visit tape positions from last to first and propagate
each assigned sensitivity through the corresponding forward-tape node.
Returns `rvs_tape`.
"""
function propagate(fwd_tape::Tape, rvs_tape::Tape)
    for idx in reverse(eachindex(rvs_tape))
        if isassigned(tape(rvs_tape), idx)
            propagate(fwd_tape[idx], rvs_tape)
        end
    end
    return rvs_tape
end
"""
    reverse_tape(y::Node, ȳ)

Initialise a `Tape` for use as a reverse-tape: it has one slot per forward
node up to `y`, all undefined except the final slot (the one for `y` itself),
which is seeded with `ȳ`.
"""
function reverse_tape(y::Node, ȳ)
    rvs = Tape(pos(y))
    rvs[pos(y)] = ȳ
    return rvs
end
"""
    Arg{N}

Singleton type used to flag, via dispatch, which argument (the `N`th) of a
function a reverse-mode sensitivity (`∇`) method refers to.
"""
struct Arg{N} end
"""
    ∇(y::Node{<:∇Scalar})
    ∇(y::Node{T}, ȳ::T) where T

Return a `Tape` object which can be indexed using `Node`s, each element of which contains
the result of multiplying `ȳ` by the transpose of the Jacobian of the function specified by
the `Tape` object in `y`. If `y` is a scalar and `ȳ = 1` then this is equivalent to
computing the gradient of `y` w.r.t. each of the elements in the `Tape`.

    ∇(f::Function, ::Type{Arg{N}}, p, y, ȳ, x...)
    ∇(x̄, f::Function, ::Type{Arg{N}}, p, y, ȳ, x...)

To implement a new reverse-mode sensitivity for the `N^{th}` argument of function `f`. `p`
is the output of `preprocess`. `x1`, `x2`,... are the inputs to the function, `y` is its
output and `ȳ` the reverse-mode sensitivity of `y`.
"""
∇(y::Node, ȳ) = propagate(tape(y), reverse_tape(y, ȳ))
# Scalar convenience method: seed the reverse pass with `one(y)`.
@inline ∇(y::Node{<:∇Scalar}) = ∇(y, one(unbox(y)))
# Fallback accumulation method: we don't necessarily know what we'll be adding
# or whether the existing `x̄` can be updated in place, so compute the fresh
# sensitivity for argument `N` and let `update!` dispatch on how to merge it.
@inline function ∇(x̄, f, ::Type{Arg{N}}, args...; kwargs...) where N
    fresh = ∇(f, Arg{N}, args...; kwargs...)
    return update!(x̄, fresh)
end
# Update regular arrays in-place. Structured array types should not be updated
# in-place, even though it technically "works"
# (https://github.com/JuliaLang/julia/issues/31674), so mutating addition is
# only permitted for `Array`s, e.g. `Vector` and `Matrix`. Mixed array/scalar
# adds should not occur, as sensitivities always share a shape, so updating an
# array with a scalar RHS is deliberately unsupported.
function update!(acc::Array{T,N}, extra::AbstractArray{S,N}) where {T,S,N}
    acc .+= extra
    return acc
end
# Fall back to out-of-place addition for everything else.
update!(acc, extra) = acc + extra
"""
    ∇(f; get_output::Bool=false)

Returns a function which, when evaluated with arguments that are accepted by
`f`, will return the gradient w.r.t. each of the arguments. If `get_output`
is `true`, the result of calling `f` on the given arguments is also returned.
"""
function ∇(f; get_output::Bool=false)
    return function(args...; kwargs...)
        leaves = Leaf.(Tape(), args)
        out = f(leaves...; kwargs...)
        if out isa Node
            grads_tape = ∇(out)
            # Arguments the output doesn't depend on get a zero gradient.
            grads = map(leaves, args) do leaf, arg
                isassigned(grads_tape, leaf) ? grads_tape[leaf] : zero(arg)
            end
        else
            # Output doesn't depend on any input: all gradients are zero.
            grads = map(zero, args)
        end
        get_output || return grads
        return (out, grads)
    end
end
# """
# ∇(f::Function)
# Returns a function which, when evaluated with arguments that are accepted by `f` (`x`),
# will return a Tuple, the first element of which is the output of the function `f` and the
# second element of which is (yet another) function `g`. `g` can either be evaluated with no
# arguments, in which case it will return the gradient of `f` evaluated at `x`.
# Alternatively, it can be evaluated with arguments of the same type and shape as the output
# of `f(x)`, in which case it is equivalent to multiplying them 'from the left' by the
# Jacobian ∂(f(x)) / ∂x.
# """
# function ∇(f::Function)
# return function(args...)
# args_ = Leaf.(Tape(), args)
# y = f(args_...)
# ∇fx = (ȳ)->∇
# end
# end
# A collection of methods for initialising nested indexable containers to zero.
# Each generated function (`zerod_container`, `oned_container`,
# `randned_container`) recurses through tuples, `Ref`s and generic indexable
# containers; the scalar/array base cases are only generated when an
# initialiser symbol is supplied (`randned_container`'s base cases are defined
# separately below, since `randn` needs a type and size, not an example value).
for (f_name, scalar_init, array_init) in
    zip((:zerod_container, :oned_container, :randned_container),
        (:zero, :one, nothing),
        (:zeros, :ones, nothing))
    if scalar_init !== nothing
        @eval @inline $f_name(x::Number) = $scalar_init(x)
    end
    if array_init !== nothing
        @eval @inline $f_name(x::AbstractArray{<:Real}) = $array_init(eltype(x), size(x))
    end
    eval(quote
        @inline $f_name(x::Tuple) = map($f_name, x)
        # Generic fallback: copy the container, then recursively re-initialise
        # each element in place.
        @inline function $f_name(x)
            y = Base.copy(x)
            for n in eachindex(y)
                @inbounds y[n] = $f_name(y[n])
            end
            return y
        end
        $f_name(x::Ref) = Ref($f_name(x[]))
    end)
end
# Base cases for `randned_container`: draw standard-normal values matching the
# element type and shape of the example `x`.
@inline randned_container(x::Number) = randn(typeof(x))
@inline randned_container(x::AbstractArray{<:Real}) = randn(eltype(x), size(x)...)
# Structured matrix types are rebuilt through their own constructor so the
# result keeps the same structure as the input.
@inline randned_container(x::Diagonal{<:Real}) = Diagonal(randn(eltype(x), size(x)...))
@inline randned_container(x::UpperTriangular{<:Real}) = UpperTriangular(randn(eltype(x), size(x)...))
@inline randned_container(x::LowerTriangular{<:Real}) = LowerTriangular(randn(eltype(x), size(x)...))
# Bare-bones FMAD implementation based on DualNumbers. Accepts a Tuple of args and returns
# a Tuple of gradients. Currently scales almost exactly linearly with the number of inputs.
# The coefficient of this scaling could be improved by implementing a version of DualNumbers
# which computes from multiple seeds at the same time.

# Build the expression `dualpart(f(x[1], ..., Dual(x[n], 1), ..., x[m]))`: a
# call of `f` with the `n`th tuple element seeded as a dual number, from which
# the partial derivative w.r.t. argument `n` is extracted via `dualpart`.
function dual_call_expr(f, x::Type{<:Tuple}, ::Type{Type{Val{n}}}) where n
    dual_call = Expr(:call, :f)
    for m in 1:Base.length(x.parameters)
        push!(dual_call.args, n == m ? :(Dual(x[$m], 1)) : :(x[$m]))
    end
    return :(dualpart($dual_call))
end
# Partial derivative of `f(x...)` w.r.t. the `n`th element of tuple `x`
# (`n` supplied as `Val{n}` so it is available at generation time).
@generated fmad(f, x, n) = dual_call_expr(f, x, n)
# Build a tuple expression with one seeded call per element of `x`.
function fmad_expr(f, x::Type{<:Tuple})
    body = Expr(:tuple)
    for n in 1:Base.length(x.parameters)
        push!(body.args, dual_call_expr(f, x, Type{Val{n}}))
    end
    return body
end
# Full forward-mode gradient: Tuple of partials of `f` w.r.t. each element of `x`.
@generated fmad(f, x) = fmad_expr(f, x)