Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Sensitivity for kron (#85)
Browse files Browse the repository at this point in the history
* sensitivity for `kron`

* in-place versions for kron sensitivities

* fixed typo in a comment

* fixed overflow bug

* kron efficiency improvements.
  • Loading branch information
wesselb committed Mar 3, 2018
1 parent bd9460a commit 8c4446a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 12 deletions.
34 changes: 34 additions & 0 deletions src/sensitivities/linalg/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,37 @@ for (f, T_A, T_B, T_Y, Ā, B̄) in binary_linalg_optimisations
@eval (::typeof($f), ::Type{Arg{1}}, p, Y::$T_Y, Ȳ::$T_Y, A::$T_A, B::$T_B) = $
@eval (::typeof($f), ::Type{Arg{2}}, p, Y::$T_Y, Ȳ::$T_Y, A::$T_A, B::$T_B) = $
end

# Sensitivities for the Kronecker product:
import Base.kron
@explicit_intercepts kron Tuple{A, A}

# The allocating versions simply allocate and then call the in-place versions.
(::typeof(kron), ::Type{Arg{1}}, p, Y::A, Ȳ::A, A::A, B::A) =
(zeros(A), kron, Arg{1}, p, Y, Ȳ, A, B)
(::typeof(kron), ::Type{Arg{2}}, p, Y::A, Ȳ::A, A::A, B::A) =
(zeros(B), kron, Arg{2}, p, Y, Ȳ, A, B)

function (Ā::A, ::typeof(kron), ::Type{Arg{1}}, p, Y::A, Ȳ::A, A::A, B::A)
(I, J), (K, L), m = size(A), size(B), length(Y)
for j = reverse(1:J), l = reverse(1:L), i = reverse(1:I)
aij, āij = A[i, j], Ā[i, j]
for k = reverse(1:K)
āij += Ȳ[m] * B[k, l]
m -= 1
end
Ā[i, j] = āij
end
return
end
function (B̄::A, ::typeof(kron), ::Type{Arg{2}}, p, Y::A, Ȳ::A, A::A, B::A)
(I, J), (K, L), m = size(A), size(B), length(Y)
for j = reverse(1:J), l = reverse(1:L), i = reverse(1:I)
aij = A[i, j]
for k = reverse(1:K)
B̄[k, l] += Ȳ[m] * aij
m -= 1
end
end
return
end
29 changes: 17 additions & 12 deletions test/sensitivities/linalg/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,28 @@
# Generate random test quantities for specific types.
∇Arrays = Union{Type{∇Array}, Type{∇ArrayOrScalar}}
trandn(rng::AbstractRNG, ::∇Arrays) = randn(rng, N, N)
trandn2(rng::AbstractRNG, ::∇Arrays) = randn(rng, N^2, N^2)
trandn(rng::AbstractRNG, ::Type{∇Scalar}) = randn(rng)
trand(rng::AbstractRNG, ::∇Arrays) = rand(rng, N, N)
trand(rng::AbstractRNG, ::Type{∇Scalar}) = rand(rng)

# Test unary linalg sensitivities.
for (f, T_In, T_Out, X̄, bounds) in Nabla.unary_linalg_optimisations
Z = trand(rng, T_In) .* (bounds[2] - bounds[1]) + bounds[1]
X = Z'Z + 1e-6 * one(Z)
Ȳ, V = eval(f)(X), trandn(rng, T_In)
V'V + 1e-6 * one(V)
@test check_errs(eval(f), Ȳ, X, 1e-2 .* V)
end
for _ in 1:5
# Test unary linalg sensitivities.
for (f, T_In, T_Out, X̄, bounds) in Nabla.unary_linalg_optimisations
Z = trand(rng, T_In) .* (bounds[2] - bounds[1]) + bounds[1]
X = Z'Z + 1e-6 * one(Z)
Ȳ, V = eval(f)(X), trandn(rng, T_In)
@test check_errs(eval(f), Ȳ, X, 1e-1 .* V)
end

# Test binary linalg sensitivities.
for (f, T_A, T_B, T_Y, Ā, B̄) in Nabla.binary_linalg_optimisations
A, B, VA, VB = trandn.(rng, (T_A, T_B, T_A, T_B))
@test check_errs(eval(f), eval(f)(A, B), (A, B), (VA, VB))
# Test binary linalg sensitivities.
for (f, T_A, T_B, T_Y, Ā, B̄) in Nabla.binary_linalg_optimisations
A, B, VA, VB = trandn.(rng, (T_A, T_B, T_A, T_B))
@test check_errs(eval(f), eval(f)(A, B), (A, B), (VA, VB))
end
A, B, VA, VB = trandn.(rng, (∇Array, ∇Array, ∇Array, ∇Array))
@test check_errs(kron, kron(A, B), (A, B), (VA, VB))
end

end
end

0 comments on commit 8c4446a

Please sign in to comment.