When using Zygote, I get a type instability only in the GPU version of my code. Any pointers on what to do about this?
using Zygote
using CUDA
# Works: `term` is a CPU array
term = rand(Float32, 100)
gradient(x::Float32 -> sum(ifelse.(term .> 0, term .+ x, term)), Float32(5.))
# Doesn't work (type-unstable): `term` is a CuArray
term = CUDA.rand(Float32, 100)
gradient(x::Float32 -> sum(ifelse.(term .> 0, term .+ x, term)), Float32(5.))
IIRC, Zygote uses ForwardDiff.jl to differentiate broadcasts on the GPU, so that could be the difference you are seeing here. Generally, the best way to debug issues like this is to step through the code with Cthulhu.jl.
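As a rough starting point for that kind of inspection, here is a minimal sketch. The helper names `loss` and `grad` are made up for illustration and are not part of Zygote. It passes `term` as an argument instead of reading a global, on the assumption that an untyped global could otherwise add instability of its own to the report. It uses a CPU array so it runs without a GPU; swapping in `CUDA.rand(Float32, 100)` should reproduce the GPU case from the question.

```julia
using Zygote
using InteractiveUtils  # provides @code_warntype outside the REPL

# Hypothetical helpers: `term` is an argument, so inference sees its concrete type.
loss(x, term) = sum(ifelse.(term .> 0, term .+ x, term))
grad(x, term) = gradient(x -> loss(x, term), x)

# CPU array here; replace with CUDA.rand(Float32, 100) to inspect the GPU path.
term = rand(Float32, 100)

# Prints the typed code of the call; non-concrete types are highlighted.
@code_warntype grad(5f0, term)
```

For interactive exploration, `using Cthulhu` followed by `@descend grad(5f0, term)` opens the same call in Cthulhu.jl, where you can descend into individual callees until you reach the one whose return type is not inferred.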
Last updated: Nov 06 2024 at 04:40 UTC