Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Merge pull request #189 from invenia/ox/chainrules
Browse files Browse the repository at this point in the history
Use ChainRules
  • Loading branch information
oxinabox committed Jul 5, 2021
2 parents f251e9e + 970893a commit 8d3dc2b
Show file tree
Hide file tree
Showing 42 changed files with 671 additions and 1,446 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/CI.yml
Expand Up @@ -30,9 +30,9 @@ jobs:
- os: windows-latest
arch: x86
include:
# Add a 1.5 job because that's what Invenia actually uses
# Add a 1.6 job because that's what Invenia actually uses
- os: ubuntu-latest
version: 1.5
version: 1.6
arch: x64
steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -81,7 +81,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: '1'
version: '1.6'
- run: |
julia --project=docs -e '
using Pkg
Expand Down
10 changes: 7 additions & 3 deletions .gitignore
@@ -1,7 +1,11 @@
*.pdf
*.DS_Store
*.jl.cov
*.jl.*.cov
*.jl.mem
*.pdf
*.DS_Store
Manifest.toml
docs/build/
docs/build
docs/site
docs/src/assets/chainrules.css
docs/src/assets/indigo.css
.vscode/settings.json
21 changes: 14 additions & 7 deletions Project.toml
@@ -1,20 +1,27 @@
name = "Nabla"
uuid = "49c96f43-aa6d-5a04-a506-44c7070ebe78"
version = "0.12.3"
version = "0.13.0"

[deps]
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
DiffRules = "0.0, 1"
DualNumbers = "0.6"
FDM = "^0.6"
SpecialFunctions = "0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1"
ChainRules = "0.8"
ChainRulesCore = "0.10.9"
ChainRulesOverloadGeneration = "0.1.2"
ExprTools = "0.1.4"
FDM = "0.6.1"
ForwardDiff = "0.10.12"
SpecialFunctions = "1.5.1"
julia = "^1.3"

[extras]
Expand Down
4 changes: 3 additions & 1 deletion docs/Project.toml
@@ -1,5 +1,7 @@
[deps]
DocThemeIndigo = "8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Nabla = "49c96f43-aa6d-5a04-a506-44c7070ebe78"

[compat]
Documenter = "~0.19"
Documenter = "0.27"
25 changes: 13 additions & 12 deletions docs/make.jl
@@ -1,25 +1,26 @@
using Documenter, Nabla
using Documenter
using DocThemeIndigo
using Nabla

const indigo = DocThemeIndigo.install(Nabla)
makedocs(
modules=[Nabla],
format=:html,
format=Documenter.HTML(
prettyurls=false,
assets=[indigo],
),
sitename="Nabla.jl",
authors="Invenia Labs",
pages=[
"Home" => "index.md",
"API" => "pages/api.md",
"Custom Sensitivities" => "pages/custom.md",
"Details" => "pages/autodiff.md",
],
sitename="Nabla.jl",
authors="Invenia Labs",
assets=[
"assets/invenia.css",
],
)


deploydocs(
repo = "github.com/invenia/Nabla.jl.git",
julia = "1.0",
target = "build",
deps = nothing,
make = nothing,
)
push_preview=true,
)
75 changes: 0 additions & 75 deletions docs/src/assets/invenia.css

This file was deleted.

Binary file added docs/src/assets/logo.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 16 additions & 1 deletion docs/src/pages/custom.md
@@ -1,4 +1,19 @@
# Custom Sensitivities
# Custom Sensitivities

!!! note "Prefer to use ChainRulesCore to define custom sensitivities"
Nabla supports the use of [ChainRulesCore](http://www.juliadiff.org/ChainRulesCore.jl/stable/) to define custom sensitivities.
It is preferred to define the custom sensitivities using `ChainRulesCore.rrule` as they will work for many AD systems, not just Nabla.
**It is also much easier, than the Nabla specific way**.
These sensitivities can be added in your own package, or for Base/StdLib functions they can be added to [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/).
To define custom sensitivities using ChainRulesCore, define a `ChainRulesCore.rrule(f, args...; kwargs...)`.
See the [ChainRules project's documentation for more information](https://www.juliadiff.org/ChainRulesCore.jl/stable/).
**If you are defining your custom sensitivities using ChainRulesCore then you do not need to read this page**, and can consider it as documenting a legacy feature.

This page exists to describe how Nabla works, and how sensitivities can be directly defined for Nabla.
Defining sensitivities this way does not make them accessible to other AD systems, but does let you do things that directly depend on how Nabla works.
It allows for specific definitions of sensitivities that are only defined for Nabla (which might work differently to more generic definitions defined for all AD).

# Legacy Method

Part of the power of Nabla is its extensibility, specifically in the form of defining
custom sensitivities for functions.
Expand Down
19 changes: 14 additions & 5 deletions src/Nabla.jl
@@ -1,9 +1,14 @@
__precompile__()

module Nabla

using SpecialFunctions
using ChainRules
using ChainRulesCore
using ChainRulesOverloadGeneration
using ExprTools: ExprTools
using ForwardDiff: ForwardDiff
using LinearAlgebra
using Random
using SpecialFunctions
using Statistics

# Some aliases used repeatedly throughout the package.
Expand Down Expand Up @@ -39,10 +44,12 @@ module Nabla
# into a separate module at some point.
include("finite_differencing.jl")

# Sensitivities via ChainRules
include("sensitivities/chainrules.jl")

# Sensitivities for the basics.
include("sensitivities/indexing.jl")
include("sensitivities/scalar.jl")
include("sensitivities/array.jl")

# Sensitivities for functionals.
include("sensitivities/functional/functional.jl")
Expand All @@ -52,14 +59,16 @@ module Nabla
# Linear algebra optimisations.
include("sensitivities/linalg/generic.jl")
include("sensitivities/linalg/symmetric.jl")
include("sensitivities/linalg/strided.jl")
include("sensitivities/linalg/blas.jl")
include("sensitivities/linalg/diagonal.jl")
include("sensitivities/linalg/triangular.jl")
include("sensitivities/linalg/factorization/cholesky.jl")
include("sensitivities/linalg/factorization/svd.jl")

# Checkpointing
include("checkpointing.jl")


# Link up to ChainRulesCore so rules are generated when new rrules are declared.
on_new_rule(generate_overload, rrule)

end # module Nabla
37 changes: 36 additions & 1 deletion src/code_transformation/util.jl
Expand Up @@ -40,6 +40,23 @@ function unionise_type(tp::Union{Symbol, Expr})
return replace_vararg(:(Union{$_tp, Node{<:$tp_clean}}), (_tp, _info))
end

"""
node_type(tp::Union{Symbol, Expr})
Returns an expression for the `Node{<:tp}`. e.g.
`node_type(:Real)` returns `:(Node{<:Real}})`.
Correctly `:(Vararg{Real})` becomes `:(Vararg{Node{<:Real}})`
This is a lot like [`unionise_type`](ref) but it doesn't permit the original type anymore.
"""
function node_type(tp::Union{Symbol, Expr})
(_tp, _info) = remove_vararg(tp)
tp_clean = (isa(_tp, Expr) && _tp.head == Symbol("<:")) ? _tp.args[1] : _tp
return replace_vararg(:(Node{<:$tp_clean}), (_tp, _info))
end


"""
replace_body(unionall::Union{Symbol, Expr}, replacement::Union{Symbol, Expr})
Expand Down Expand Up @@ -91,6 +108,24 @@ function remove_vararg(typ::Expr)
if isa_vararg(typ)
body = get_body(typ)
new_typ = replace_body(typ, body.args[2])

# This is a bit ugly:
# handle interally `where N` from `typ = :(Vararg{FOO, N} where N)` which results in
# `body = :(Vararg{FOO, N})` and `new_type = Foo where N`, we don't need to keep it
# at all, the `where N` wasn't doing anything to begin with, so we just strip it out
if Meta.isexpr(new_typ, :where) && Meta.isexpr(body, :curly, 3)
@assert body.args[1] == :Vararg
T = body.args[2]
N = body.args[3]
if new_typ.args == [T, N] # ($T where $N)
body = :(Vararg{T})
new_typ = T
elseif T == new_typ.args[1] && N new_typ.args[2:end] # ($T where {?, $N, ?})
body = :(Vararg{T})
filter!(!isequal(N), new_typ.args)
end
end

vararg_info = length(body.args) == 3 ? body.args[3] : :Vararg
return new_typ, vararg_info
else
Expand All @@ -107,7 +142,7 @@ Convert `typ` to the `Vararg` containing elements of type `typ` specified by
replace_vararg(typ::SymOrExpr, vararg_info::Tuple) =
vararg_info[2] == :nothing ?
typ :
vararg_info[2] == :no_N || vararg_info[2] == :Vararg ?
vararg_info[2] == :no_N || vararg_info[2] == :Vararg ? #TODO: :no_N is impossible now?
replace_body(typ, :(Vararg{$(get_body(typ))})) :
replace_body(typ, :(Vararg{$(get_body(typ)), $(vararg_info[2])}))

Expand Down

2 comments on commit 8d3dc2b

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/40306

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.0 -m "<description of version>" 8d3dc2bafaae0eb9d26b64dfabae3253c9ec01f4
git push origin v0.13.0

Please sign in to comment.