Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Merge pull request #120 from invenia/ed/broadcast-styles
Browse files Browse the repository at this point in the history
Fix broadcast style promotion
  • Loading branch information
iamed2 committed Jan 7, 2019
2 parents 1104dcf + 149add4 commit 79b0a98
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/sensitivities/functional/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,21 @@ map(f, x::AbstractArray{<:Number}) =

# Implementation of sensitivities w.r.t. `broadcast`.
using Base.Broadcast
using Base.Broadcast: Broadcasted, broadcastable, broadcast_axes, combine_axes
using Base.Broadcast: Broadcasted, broadcastable, broadcast_axes, combine_axes, result_style

struct NodeStyle{S} <: BroadcastStyle end

Base.BroadcastStyle(::Type{<:Node{T}}) where {T} = NodeStyle{BroadcastStyle(T)}()

Base.BroadcastStyle(::NodeStyle{S}, ::NodeStyle{S}) where {S} = NodeStyle{S}()
Base.BroadcastStyle(::NodeStyle{S1}, ::NodeStyle{S2}) where {S1,S2} =
NodeStyle{BroadcastStyle(S1, S2)}()
Base.BroadcastStyle(::NodeStyle{S}, B::BroadcastStyle) where {S} =
NodeStyle{BroadcastStyle(S, B)}()
function Base.BroadcastStyle(::NodeStyle{S1}, ::NodeStyle{S2}) where {S1,S2}
promoted = result_style(S1, S2)
promoted isa Broadcast.Unknown ? promoted : NodeStyle{promoted}()
end
function Base.BroadcastStyle(::NodeStyle{S}, B::BroadcastStyle) where {S}
promoted = result_style(S, B)
promoted isa Broadcast.Unknown ? promoted : NodeStyle{promoted}()
end

Broadcast.broadcast_axes(x::Node) = broadcast_axes(x.val)
Broadcast.broadcastable(x::Node) = x
Expand Down
18 changes: 18 additions & 0 deletions test/sensitivities/functional/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,15 @@ using DiffRules: diffrule, hasdiffrule
check_binary_dot(eval(f), x, Ref(rand(rng, y_distr)))
check_binary_dot(eval(f), rand(rng, x_distr), rand(rng, y_distr))
end

# test with other broadcast styles
let
a = Diagonal(ones(3))
b = ones(3, 3)
check_binary_dot(+, a, b)
check_binary_dot(+, b, a)
check_binary_dot(+, a, a)
end
end

# Check that the number of allocations which happen in the reverse pass of `map` and
Expand Down Expand Up @@ -254,4 +263,13 @@ using DiffRules: diffrule, hasdiffrule
@test (f)(Float64[1,2,3])[1] == Float64[2,4,6]
@test (f; get_output=true)(Float64[1,2,3])[1].val == f(Float64[1,2,3])
end

# fused broadcasting with different styles
let
f(x) = sum(Symmetric(x) .+ 0.0001 .* Diagonal(ones(size(x, 1))))
a = rand(3, 3)
a += transpose(a)
@test (f)(a)[1] == Float64[1 2 2; 0 1 2; 0 0 1]
@test (f; get_output=true)(a)[1].val == f(a)
end
end

0 comments on commit 79b0a98

Please sign in to comment.