When using Zygote, I get a type instability only in the GPU version of my code. Any pointers on what to do about this?
```julia
using Zygote
using CUDA

# Works! (CPU array)
term = rand(Float32, 100)
gradient(x::Float32 -> sum(ifelse.(term .> 0, term .+ x, term)), Float32(5.))

# Doesn't work! (GPU array)
term = CUDA.rand(Float32, 100)
gradient(x::Float32 -> sum(ifelse.(term .> 0, term .+ x, term)), Float32(5.))
```
IIRC, Zygote uses ForwardDiff.jl (dual numbers) to differentiate broadcasts over GPU arrays, so that could be the difference you are seeing here. Generally, the best way to debug issues like this is to step through the code with Cthulhu.jl (e.g. `@descend`).
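A plausible mechanism, offered as an assumption rather than something confirmed in this thread: `ifelse` returns whichever branch it selects without promoting the two. If one branch carries a dual number and the other a plain `Float32`, the broadcast has no single concrete element type. The sketch below shows this on the CPU without Zygote or CUDA. It uses a hypothetical `MiniDual` type as a stand-in for `ForwardDiff.Dual` and a hypothetical `lift` helper.

```julia
# Hypothetical stand-in for ForwardDiff.Dual, for illustration only:
# holds a value and its derivative with respect to a single input.
struct MiniDual
    val::Float32
    der::Float32
end

# Treat a MiniDual as a scalar inside broadcasts.
Base.broadcastable(d::MiniDual) = Ref(d)

# Adding a plain Float32 shifts the value and keeps the derivative.
Base.:+(a::Float32, b::MiniDual) = MiniDual(a + b.val, b.der)

term = Float32[-1, 2, -3, 4]
x = MiniDual(5f0, 1f0)  # seed dx/dx = 1

# `ifelse` does not promote its branches, so elements are a mix of
# Float32 and MiniDual and the result's element type is not concrete.
y_mixed = ifelse.(term .> 0, term .+ x, term)

# Hypothetical helper: lift the other branch to a MiniDual with zero
# derivative so every element has the same concrete type.
lift(t::Float32) = MiniDual(t, 0f0)
y_uniform = ifelse.(term .> 0, term .+ x, lift.(term))

# Summing the derivative parts gives d/dx of the sum: one per positive entry.
dsum = sum(d -> d.der, y_uniform)
```

On the CPU, Julia's broadcast tolerates `y_mixed` by widening to a non-concrete element type, while `eltype(y_uniform)` is `MiniDual`. GPU array broadcasts generally need a concrete, isbits element type, so the mixed case is the kind of thing that can work on the CPU and fail on the GPU.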
Last updated: Mar 04 2025 at 04:41 UTC