When using Zygote, I get a type instability only in the GPU version of my code. Any pointers on what to do about this?
```julia
using Zygote
using CUDA

# Works!
term = rand(Float32, 100)
gradient(x::Float32 -> sum(ifelse.(term .> 0, term .+ x, term)), Float32(5.))

# Doesn't Work!
term = CUDA.rand(Float32, 100)
gradient(x::Float32 -> sum(ifelse.(term .> 0, term .+ x, term)), Float32(5.))
```
IIRC Zygote uses ForwardDiff.jl for broadcasts on the GPU, so that could be the difference you are seeing here. Generally the best way to debug issues like this is to step through the code using Cthulhu.jl.
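A minimal CPU sketch for comparing the two differentiation paths. It assumes Zygote.jl and ForwardDiff.jl are installed (ForwardDiff is not part of the original snippet). The helper names `loss`, `grad_zygote` and `grad_forward` are hypothetical. The sketch computes the same derivative twice: once with Zygote's reverse mode, as in the working example, and once with ForwardDiff's dual numbers, the forward-mode mechanism the answer recalls Zygote using for GPU broadcasts. It neither reproduces nor fixes the GPU issue. It only makes the two paths easy to inspect side by side, and the final comment shows where Cthulhu.jl's `@descend` could be pointed.

```julia
using Zygote
using ForwardDiff

# Same expression as in the question, with the array passed in explicitly
# instead of being read from the non-const global `term` (hypothetical helper).
loss(x, t) = sum(ifelse.(t .> 0, t .+ x, t))

# Reverse mode through Zygote, as in the working CPU example.
grad_zygote(t) = gradient(x -> loss(x, t), 5f0)[1]

# Forward mode with dual numbers through ForwardDiff.
grad_forward(t) = ForwardDiff.derivative(x -> loss(x, t), 5f0)

term = randn(Float32, 100)  # mix of positive and negative entries

g_zygote = grad_zygote(term)
g_forward = grad_forward(term)

# Each positive entry contributes d/dx (t + x) = 1 and every other entry 0,
# so both results should equal the number of positive entries.
expected = count(>(0), term)

println((g_zygote, g_forward, expected))

# To inspect inference interactively (requires Cthulhu.jl):
#   using Cthulhu
#   @descend grad_zygote(term)
```

Passing the array as an argument, rather than closing over the global `term`, is a deliberate choice. Non-`const` globals are not type-stable in Julia, so a closure that reads them can show instabilities unrelated to the broadcast itself. This thread does not establish whether that contributes to the GPU case.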
Last updated: Dec 28 2024 at 04:38 UTC