Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an error for broadcasting with CUDA + complex numbers, etc #1225

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

mcabbott
Copy link
Member

Xref #1215

src/lib/broadcast.jl Outdated Show resolved Hide resolved
@mcabbott mcabbott marked this pull request as draft May 13, 2022 12:04
@jgreener64
Copy link
Contributor

I think making this an error is a good idea. Ideally there would be a way to not error for custom types when you know it is okay to not track the gradients. I guess you can define _dual_safearg for your custom type but this might be worth describing in the error message, or exposing in a different way.

@mcabbott
Copy link
Member Author

Yes. I guess the thing you overload should eventually be something like JuliaDiff/ChainRulesCore.jl#528

However, at the moment I believe you get errors from unbroadcast not having appropriate methods, if you try to use some weird type (even just a Symbol, IIRC). So it ought to be safe to make these deliberate errors now, and adjustable later.

@mcabbott mcabbott added the CUDA All things GPU label Jul 4, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CUDA All things GPU
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants