Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Merge pull request #122 from invenia/ed/fix-diagonal
Browse files Browse the repository at this point in the history
Fix Diagonal sensitivity
  • Loading branch information
iamed2 committed Jan 8, 2019
2 parents b510378 + 004951f commit 30d5ccc
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Nabla"
uuid = "49c96f43-aa6d-5a04-a506-44c7070ebe78"
version = "0.3.1"
version = "0.3.2"

[deps]
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Expand Down
8 changes: 4 additions & 4 deletions src/sensitivities/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,21 @@ function ∇(
::Type{Arg{1}},
p,
Y::∇ScalarDiag,
::ScalarDiag,
::AbstractMatrix,
x::∇AbstractVector,
)
return copyto!(similar(x), .diag)
return copyto!(similar(x), diag(Ȳ))
end
function (
::∇AbstractVector,
::Type{Diagonal},
::Type{Arg{1}},
p,
Y::∇ScalarDiag,
::ScalarDiag,
::AbstractMatrix,
x::∇AbstractVector,
)
return broadcast!(+, x̄, x̄, .diag)
return broadcast!(+, x̄, x̄, diag(Ȳ))
end

@explicit_intercepts Diagonal Tuple{∇AbstractMatrix}
Expand Down
9 changes: 9 additions & 0 deletions test/sensitivities/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,13 @@
@test check_errs(logdet, 10.0, A, VA)
end
end

# test that both diagonal implementations are correct and the same
let
f1(x) = sum(Diagonal(x))
f2(x) = sum(diagm(0 => x))

@test (f1)(ones(4))[1] == ones(4)
@test (f2)(ones(4))[1] == ones(4)
end
end

0 comments on commit 30d5ccc

Please sign in to comment.