Skip to content

Commit

Permalink
fix some BlockDiagIEB AD issues
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed Oct 5, 2023
1 parent 469e1f7 commit 8788418
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
8 changes: 4 additions & 4 deletions src/field_vectors.jl
Expand Up @@ -66,25 +66,25 @@ promote_rule(::Type{F}, ::Type{<:Scalar}) where {F<:Field} = F
end

# Closed-form square root of a 2x2 matrix whose entries are diagonal operators.
# With s = sqrt(det(A)) and t = pinv(sqrt(a + d + 2s)), the result is
# t * (A + s*I), written out entry by entry. The scalar 2x2 formula is applied
# directly to the operator entries (NOTE(review): this presumes the DiagOp
# entries commute with each other — confirm). Returns a matrix of type `SA`.
@auto_adjoint function sqrt(A::SA) where {SA<:StaticMatrix{2,2,<:DiagOp}}
    # b is the upper-right entry A[1,2] and c the lower-left entry A[2,1];
    # they must be read separately so that non-symmetric A is handled.
    a, c, b, d = A[1,1], A[2,1], A[1,2], A[2,2]
    s = sqrt(a*d-b*c)
    t = pinv(sqrt(a+(d+2s)))
    SA([t*(a+s) t*b; t*c t*(d+s)])
end

# Determinant of a 2x2 matrix of diagonal operators, computed with the scalar
# formula a*d - b*c applied to the operator entries. The result is whatever
# type the entry arithmetic produces.
@auto_adjoint function det(A::StaticMatrix{2,2,<:DiagOp})
    # b is the upper-right entry A[1,2] and c the lower-left entry A[2,1].
    a, c, b, d = A[1,1], A[2,1], A[1,2], A[2,2]
    a*d-b*c
end

# (Pseudo-)inverse of a 2x2 matrix of diagonal operators via the adjugate
# formula: pinv(det(A)) * [d -b; -c a]. `pinv` (rather than `inv`) is applied
# to the determinant. Returns a matrix of the same static type `SA`.
@auto_adjoint function pinv(A::SA) where {SA<:StaticMatrix{2,2,<:DiagOp}}
    # b is the upper-right entry A[1,2] and c the lower-left entry A[2,1];
    # swapping them would transpose the off-diagonal part of the result.
    a, c, b, d = A[1,1], A[2,1], A[1,2], A[2,2]
    idet = pinv(a*d-b*c)
    SA([d*idet -(b*idet); -(c*idet) a*idet])
end

function pinv!(dst::StaticMatrix{2,2,<:DiagOp}, src::StaticMatrix{2,2,<:DiagOp})
a, c, b, d = src[1,1], src[2,1], src[2,1], src[2,2]
a, c, b, d = src[1,1], src[2,1], src[1,2], src[2,2]
det⁻¹ = pinv(@. a*d-b*c)
@. dst[1,1] = det⁻¹ * d
@. dst[1,2] = -det⁻¹ * b
Expand Down
33 changes: 23 additions & 10 deletions src/specialops.jl
Expand Up @@ -56,12 +56,14 @@ end
# We store the 2x2 block as a 2x2 SizedMatrix, ΣTE, so that we can easily
# call sqrt/inv on it, and the ΣBB block separately as ΣB. This type
# is generic with regard to the field type, F.
# Block-diagonal operator holding a 2x2 block `ΣTE` (a SizedMatrix of
# operators) and a separate `ΣB` block.
#
# Type parameters:
#   P   - promoted type of `diag(d).metadata` over the blocks that are `Diagonal`
#   T   - promoted element type over the blocks that are `Diagonal`
#   DTE - element type of the 2x2 `ΣTE` matrix
#   DB  - type of `ΣB`
struct BlockDiagIEB{P,T,DTE,DB} <: ImplicitOp{T}
    ΣTE :: SizedMatrix{2,2,DTE,2,Matrix{DTE}}
    ΣB :: DB
    # Accepts any AbstractMatrix that can be wrapped as a 2x2 SizedMatrix, and
    # any ΣB. Blocks that are not `Diagonal` contribute `Union{}` to the
    # promotions below, i.e. they do not influence T or P.
    # NOTE(review): presumably this is what lets AD tangents with
    # non-Diagonal entries be wrapped in a BlockDiagIEB — confirm.
    function BlockDiagIEB(ΣTE::AbstractMatrix, ΣB)
        ΣTE = SizedMatrix{2,2}(ΣTE)
        T = promote_type(map(d -> d isa Diagonal ? eltype(d) : Union{}, (ΣTE...,ΣB))...)
        P = promote_type(map(d -> d isa Diagonal ? typeof(diag(d).metadata) : Union{}, (ΣTE...,ΣB))...)
        new{P,T,eltype(ΣTE),typeof(ΣB)}(ΣTE, ΣB)
    end
end
@adjoint function BlockDiagIEB(ΣTE, ΣB)
Expand All @@ -83,23 +85,24 @@ end
# The operator size is three times the size of the ΣB block in each dimension.
size(L::BlockDiagIEB) = 3 .* size(L.ΣB)
# Adjoint and square root act block-wise: on the 2x2 ΣTE matrix and on ΣB
# separately, and the results are re-wrapped in a BlockDiagIEB.
adjoint(L::BlockDiagIEB) = BlockDiagIEB(adjoint(L.ΣTE), adjoint(L.ΣB))
sqrt(L::BlockDiagIEB) = BlockDiagIEB(sqrt(L.ΣTE), sqrt(L.ΣB))
# (Pseudo-)inverse, taken block-wise on ΣTE and ΣB. Defined once, as a plain
# method without @auto_adjoint.
# NOTE(review): presumably AD then differentiates through the block-level
# pinv rules and the BlockDiagIEB constructor adjoint — confirm.
pinv(L::BlockDiagIEB) = BlockDiagIEB(pinv(L.ΣTE), pinv(L.ΣB))
# Diagonal of the operator as a BaseIEBFourier{P}, assembled from the diagonals
# of ΣTE[1,1], ΣTE[2,2] and ΣB; the off-diagonal ΣTE entries are not used.
diag(L::BlockDiagIEB{P}) where {P} = BaseIEBFourier{P}(diag(L.ΣTE[1,1]), diag(L.ΣTE[2,2]), diag(L.ΣB))
# Calls `similar` on every ΣTE entry and on ΣB and wraps the results.
similar(L::BlockDiagIEB) = BlockDiagIEB(similar.(L.ΣTE), similar(L.ΣB))
# The storage is reported from ΣB only.
get_storage(L::BlockDiagIEB) = get_storage(L.ΣB)
# Adapts each ΣTE entry and ΣB to `storage`; `Ref` keeps `storage` itself from
# being broadcast over.
adapt_structure(storage, L::BlockDiagIEB) = BlockDiagIEB(adapt.(Ref(storage), L.ΣTE), adapt(storage, L.ΣB))
# Draws a sample as sqrt(L) applied to `randn!` noise written into a field
# `similar` to diag(L); `Nbatch` is splatted into that `similar` call.
simulate(rng::AbstractRNG, L::BlockDiagIEB; Nbatch=()) = sqrt(L) * randn!(rng, similar(diag(L), Nbatch...))
# Log-determinant of the block-diagonal operator: the logdet of the 2x2 ΣTE
# block (taken via its operator-valued determinant) plus the logdet of ΣB.
# Defined once, with @auto_adjoint so the definition is also registered for AD.
@auto_adjoint logdet(L::BlockDiagIEB) = logdet(det(L.ΣTE)) + logdet(L.ΣB)
# arithmetic
# L * D with D diagonal in the I/E/B basis: the first ΣTE column is multiplied
# by D[:I], the second by D[:E], and ΣB by D[:B].
*(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB([L.ΣTE[1,1]*D[:I] L.ΣTE[1,2]*D[:E]; L.ΣTE[2,1]*D[:I] L.ΣTE[2,2]*D[:E]], L.ΣB*D[:B])
# D * L: the first ΣTE row is multiplied by D[:I], the second by D[:E], and ΣB
# by D[:B].
*(D::DiagOp{<:BaseIEBFourier}, L::BlockDiagIEB) = BlockDiagIEB([L.ΣTE[1,1]*D[:I] L.ΣTE[1,2]*D[:I]; L.ΣTE[2,1]*D[:E] L.ΣTE[2,2]*D[:E]], L.ΣB*D[:B])
# L + D: D only touches the diagonal entries ΣTE[1,1], ΣTE[2,2] and ΣB.
+(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB([L.ΣTE[1,1]+D[:I] L.ΣTE[1,2]; L.ΣTE[2,1] L.ΣTE[2,2]+D[:E]], L.ΣB+D[:B])
# Product and sum of two BlockDiagIEB operators, taken block by block.
*(La::BlockDiagIEB, Lb::BlockDiagIEB) = BlockDiagIEB(La.ΣTE * Lb.ΣTE, La.ΣB * Lb.ΣB)
+(La::BlockDiagIEB, Lb::BlockDiagIEB) = BlockDiagIEB(La.ΣTE + Lb.ΣTE, La.ΣB + Lb.ΣB)
# Negation: every ΣTE entry and ΣB are negated.
-(L::BlockDiagIEB) = BlockDiagIEB(.-(L.ΣTE), -L.ΣB)
# L + U for a UniformScaling U: U is added to ΣTE[1,1], ΣTE[2,2] and ΣB only.
+(L::BlockDiagIEB, U::UniformScaling{<:Scalar}) = BlockDiagIEB([(L.ΣTE[1,1]+U) L.ΣTE[1,2]; L.ΣTE[2,1] (L.ΣTE[2,2]+U)], L.ΣB+U)
# Scaling by a scalar λ. The 2x2 block is rebuilt entry by entry (rather than
# as `L.ΣTE * λ`) and the method is defined once, with @auto_adjoint.
# NOTE(review): the entry-wise form appears to be deliberate for AD — keep it
# unless the gradient tests are re-run.
@auto_adjoint *(L::BlockDiagIEB, λ::Scalar) = BlockDiagIEB([L.ΣTE[1,1]*λ L.ΣTE[1,2]*λ; L.ΣTE[2,1]*λ L.ΣTE[2,2]*λ], L.ΣB*λ)
# U + L forwards to the L + U method, so both argument orders share one
# implementation.
+(U::UniformScaling{<:Scalar}, L::BlockDiagIEB) = L + U
# λ * L forwards to the L * λ method. The scalar argument must be named λ
# because the right-hand side refers to it. Defined once, with @auto_adjoint.
@auto_adjoint *(λ::Scalar, L::BlockDiagIEB) = L * λ
# indexing
function getindex(L::BlockDiagIEB{P}, k::Symbol) where {P}
@match k begin
Expand All @@ -112,6 +115,16 @@ function getindex(L::BlockDiagIEB{P}, k::Symbol) where {P}
_ => throw(ArgumentError("Invalid BlockDiagIEB index: $k"))
end
end
# Custom adjoint for field access on a BlockDiagIEB. The forward pass is a
# plain `getfield`; the pullback wraps the incoming field cotangent Δ in a
# BlockDiagIEB whose other field is filled with zeros, and returns `nothing`
# for the symbol argument.
@adjoint function Base.getproperty(L::BlockDiagIEB, k::Symbol)
    value = getfield(L, k)
    function getproperty_pullback(Δ)
        # Any k other than :ΣTE is treated as the ΣB field.
        grad = k == :ΣTE ?
            BlockDiagIEB(Δ, zero(getfield(L, :ΣB))) :
            BlockDiagIEB(zero.(getfield(L, :ΣTE)), Δ)
        return (grad, nothing)
    end
    return value, getproperty_pullback
end
# hashing
# Hash a BlockDiagIEB by folding `hash` (right to left, seeded with `h`) over
# its type, three of the ΣTE entries and ΣB.
# NOTE(review): ΣTE[2,1] is not mixed in, so operators that differ only in
# that entry hash identically — confirm this is intended.
function hash(L::BlockDiagIEB, h::UInt64)
    parts = (typeof(L), L.ΣTE[1,1], L.ΣTE[1,2], L.ΣTE[2,2], L.ΣB)
    return foldr(hash, parts, init=h)
end

Expand Down
4 changes: 3 additions & 1 deletion test/runtests.jl
Expand Up @@ -610,7 +610,9 @@ end
atol = pol==:IP ? 30 : 3
@test_real_gradient(α -> logpdf( ds; f = f + α * δf, ϕ = ϕ + α * δϕ), 0, atol=atol)
@test_real_gradient(α -> logpdf(Mixed(ds); f° = f° + α * δf, ϕ° = ϕ° + α * δϕ), 0, atol=atol)

@test_real_gradient(r -> logpdf( ds; f, ϕ, θ=(;r)), T(0.1), atol=atol)
@test_real_gradient(r -> logpdf(Mixed(ds); f°, ϕ°, θ=(;r)), T(0.1), atol=atol)

end

end
Expand Down

0 comments on commit 8788418

Please sign in to comment.