For a function with Float32 return type, the gradient type returned is Vector{Float64} for the broadcasted function, but Float32 for the non-broadcast version.
```julia
using Zygote

c = Int32[1, 2, 3]

function f3(a)
    b = Float32[1, 2, 3]
    sum(a .* b)
end

f3(c)  # Float32 return type
f3'(c) # Vector{Float64} gradient -> desirable would be Float32
```
Non-broadcast version, which works as expected:

```julia
using Zygote

c = 3f0

function f3(a)
    b = Int64(1)
    sum(a .* b)
end

f3(c)  # Float32 return type
f3'(c) # Float32 gradient
```
jeremiahpslewis changed the title from "Float32 function return type, but Vector{Float64} gradient type (breaks some GPU code due to type promotion to unsupported type)" to "Broadcast + Zygote = Weird Type Instability?" on Aug 8, 2023.
This is not a type instability issue (the code is 100% type stable), but a weird edge case of how promotion rules work in Julia that I wasn't aware of. In brief, the scalar version relies on promote and by extension promote_rule, which stipulate that <floating point type> plus <int type> -> <floating point type>. In contrast, the array version uses the ProjectTo machinery in ChainRulesCore to ensure correct types are maintained for AD. That path calls float(Int32), which returns Float64 for all core (U)Int types!
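The difference between the two paths can be checked directly at the REPL (a minimal sketch of the promotion behaviour described above, using only Base functions):

```julia
# Scalar path: promote_type keeps Float32 when mixed with an integer type.
promote_type(Float32, Int32)  # Float32

# Array/ProjectTo path: float(T) maps every core integer type to Float64.
float(Int32)                  # Float64
float(UInt8)                  # Float64

# Hence a Float32 scalar times an Int stays Float32 in the scalar version,
# while the projection of an Int32 array lands on Float64.
typeof(3f0 * Int32(2))        # Float32
```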
Here are a few ideas for addressing this, ranked in order of difficulty:

1. Integer types technically aren't considered differentiable by many people, so converting your Ints to floats and differentiating wrt those would be less likely to hit any unwritten behaviour there.
2. You could ask over on the ChainRulesCore side whether the promotion rules could be tweaked. I notice there are currently no tests for ProjectTo(::Int32)(::Float32), so it may just be a missed edge case.
3. Zygote could add custom paths to its internal projection machinery (which currently mostly wraps ChainRulesCore's) to handle cases like these. I rank this most difficult because we'd likely be reinventing the wheel in a less well-maintained and non-reusable way.
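The first option can be sketched as follows (a workaround, not a fix in Zygote itself; assumes Zygote is installed):

```julia
using Zygote

function f3(a)
    b = Float32[1, 2, 3]
    sum(a .* b)
end

c = Int32[1, 2, 3]

# Differentiate wrt a Float32 array instead of the Int32 one.
# ProjectTo now projects onto Float32, so the gradient is Vector{Float32}.
grad = f3'(Float32.(c))
eltype(grad)  # Float32
```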
First identified here: JuliaGPU/GPUArrays.jl#484