When using Zygote, I get a type instability only in the GPU version of my code. Any pointers on what to do about this?
```julia
using Zygote
using CUDA
# Works!
term = rand(Float32, 100)
gradient(x::Float32 -> sum(ifelse.(term .> 0, term .+ x, term)), Float32(5.))
# Doesn't Work!
term = CUDA.rand(Float32, 100)
gradient(x::Float32 -> sum(ifelse.(term .> 0, term .+ x, term)), Float32(5.))
```
IIRC, Zygote uses ForwardDiff.jl (dual numbers) to differentiate broadcasts on the GPU, so that could explain the difference you're seeing. In general, the best way to debug issues like this is to step through the code with Cthulhu.jl.
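A CPU-only sketch of why dual numbers could matter here, under the assumption (not confirmed in the thread) that the GPU broadcast kernel receives `ForwardDiff.Dual` values for `x`. With such an input, `ifelse(t > 0, t + x, t)` returns a `Dual` in one branch and a plain `Float32` in the other, so the inferred element type is a non-concrete `Union`. CPU broadcasting can widen the output array at runtime, while GPU broadcasting generally needs a concrete element type. The names `xd`, `f_orig` and `f_stable`, and the rewrite `term .+ x .* (term .> 0)`, are illustrative additions rather than anything from the original posts; the block assumes ForwardDiff.jl and Zygote.jl are installed in the active environment.

```julia
using ForwardDiff: Dual
using Zygote

# Hypothetical stand-in for a dual number that may reach the broadcast
# kernel: value 5, a single partial (derivative seed) of 1.
xd = Dual(5f0, 1f0)

# Per-element function from the question: returns a Dual when t > 0,
# but the plain Float32 `t` otherwise.
f_orig(t, x) = ifelse(t > 0, t + x, t)

# Illustrative rewrite: x * (t > 0) equals x when t > 0 and zero(x) otherwise,
# so both cases produce the same type as t + x.
f_stable(t, x) = t + x * (t > 0)

# Element types that inference assigns for (Float32, Dual) inputs
T_orig = Base.promote_op(f_orig, Float32, typeof(xd))
T_stable = Base.promote_op(f_stable, Float32, typeof(xd))
println(T_orig)    # non-concrete: the two ifelse branches have different types
println(T_stable)  # the Dual type alone, which is concrete

# On the CPU both formulations should give the same gradient,
# namely the number of positive entries in `term`.
term = rand(Float32, 100) .- 0.5f0
g_orig, = gradient(x -> sum(ifelse.(term .> 0, term .+ x, term)), 5f0)
g_stable, = gradient(x -> sum(term .+ x .* (term .> 0)), 5f0)
println((g_orig, g_stable, count(term .> 0)))
```

If this matches what happens inside Zygote's GPU broadcast path, the same rewrite could be tried with the `CUDA.rand` array from the question; that GPU run is not verified here, and Cthulhu.jl's `@descend` remains the more direct way to confirm where the instability actually arises.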
Last updated: Jan 29 2025 at 04:38 UTC