Broadcast + Zygote: surprising types from Int32 input #1445

Open
jeremiahpslewis opened this issue Aug 8, 2023 · 2 comments
Labels
ChainRules adjoint -> rrule, and further integration


jeremiahpslewis commented Aug 8, 2023

For a function with a Float32 return type, the gradient returned for the broadcasted version is a Vector{Float64}, whereas the non-broadcast version returns a Float32 gradient.

```julia
using Zygote

c = Int32[1, 2, 3]

function f3(a)
    b = Float32[1, 2, 3]
    sum(a .* b)
end

f3(c) # Float32 return type

f3'(c) # Vector{Float64} gradient; Vector{Float32} would be desirable
```

First identified here: JuliaGPU/GPUArrays.jl#484

Non-broadcast version, which works as expected:

```julia
using Zygote

c = 3f0

function f3(a)
    b = Int64(1)
    sum(a .* b)
end

f3(c) # Float32 return type

f3'(c) # Float32
```
@jeremiahpslewis jeremiahpslewis changed the title Float32 function return type, but Vector{Float64} gradient type (breaks some GPU code due to type promotion to unsupported type) Broadcast + Zygote = Weird Type Instability? Aug 8, 2023
@ToucheSir
Member

This is not a type instability issue (the code is 100% type stable), but a weird edge case of how promotion rules work in Julia that I wasn't aware of. In brief, the scalar version relies on `promote` and, by extension, `promote_rule`, which stipulates that <floating point type> combined with <int type> gives <floating point type>. In contrast, the array version uses the `ProjectTo` machinery in ChainRulesCore to ensure correct types are maintained for AD. This calls `float(Int32)`, which returns `Float64` for all core (U)Int types, regardless of their bit width!
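The two promotion paths described above can be checked directly in Base Julia, without Zygote or ChainRulesCore loaded. This is a minimal sketch that only calls Base functions (`promote_type`, `float`); it does not exercise the AD machinery itself.

```julia
# Path 1: ordinary arithmetic promotion (promote / promote_rule).
# Float32 combined with a built-in integer type stays Float32.
@show promote_type(Float32, Int32)   # Float32
@show typeof(Int32(2) * 3f0)         # Float32

# Path 2: `float` applied to an integer type, the step the array
# projection goes through. Built-in (U)Int types all map to Float64,
# no matter how narrow they are.
@show float(Int32)                   # Float64
@show float(UInt8)                   # Float64

# Floating-point types are left unchanged by `float`.
@show float(Float32)                 # Float32
```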

Here are a few ideas for addressing this, in increasing order of difficulty:

  1. Many people don't consider integer types differentiable, so converting your Ints to floats and differentiating with respect to those is less likely to hit undocumented behaviour.
  2. You could ask over on the ChainRulesCore side whether the promotion rules could be tweaked. I notice there are currently no tests for ProjectTo(::Int32)(::Float32), so it may just be a missed edge case.
  3. Zygote could add custom paths to its internal projection machinery (which currently mostly wraps ChainRulesCore's) to handle cases like these. I rank this most difficult because we'd likely be reinventing the wheel in a less well-maintained and non-reusable way.
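To make ideas 2 and 3 more concrete, here is a hypothetical sketch of two rules for choosing a gradient's element type from an integer primal and a floating-point cotangent. `project_via_float` and `project_via_promote` are illustrative names invented for this sketch; they are not ChainRulesCore's `ProjectTo` API, which handles many more cases (complex numbers, structured arrays, and so on).

```julia
# Hypothetical helper mirroring the current behaviour described above:
# the gradient eltype is chosen from the primal alone via `float`,
# so an Int32 primal always yields a Float64 gradient.
function project_via_float(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number})
    T = float(eltype(x))
    return T.(dx)
end

# Hypothetical alternative: let the cotangent's eltype take part in
# promotion, so an Int32 primal with a Float32 cotangent stays Float32.
# The outer `float` keeps the result a floating-point type even if
# both inputs happen to be integers.
function project_via_promote(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number})
    T = float(promote_type(eltype(x), eltype(dx)))
    return T.(dx)
end

c  = Int32[1, 2, 3]    # integer primal, as in the issue
dc = Float32[1, 2, 3]  # Float32 cotangent, matching `b` in the issue's f3

@show eltype(project_via_float(c, dc))    # Float64
@show eltype(project_via_promote(c, dc))  # Float32
```

The promotion-based rule trades one surprise for another: the gradient's type would then depend on the cotangent as well as the primal, which is part of why this is a question for the ChainRulesCore side rather than a one-line fix.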

@jeremiahpslewis
Author

Thanks for the very clear explanation! I'll stick with option 1 for now; as you suggest, the other two are relatively complex.

@mcabbott mcabbott changed the title Broadcast + Zygote = Weird Type Instability? Broadcast + Zygote: surprising types from Int32 input Aug 19, 2023