Stream: helpdesk (published)

Topic: Type Instability for GPU only using zygote


Jimmie Butler (May 31 2021 at 12:02):

When using Zygote, I get a type instability only in the GPU version of my code. Any pointers on what to do about this?

using Zygote
using CUDA

# Works!
term = rand(Float32, 100)
gradient(x::Float32 -> sum(ifelse.(term .> 0, term .+ x, term)), Float32(5.))

# Doesn't Work!
term = CUDA.rand(Float32, 100)
gradient(x::Float32 -> sum(ifelse.(term .> 0, term .+ x, term)), Float32(5.))

Simeon Schaub (May 31 2021 at 18:47):

IIRC Zygote uses ForwardDiff.jl for broadcasts on the GPU, so that could be the difference you are seeing here. Generally the best way to debug issues like this is to step through the code using Cthulhu.jl.
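
A minimal sketch of how the suggestion above might be tried, assuming a CUDA-capable GPU and that Cthulhu.jl is installed. The helper names `f` and `g` are hypothetical and not from the original post. Passing `term` as an argument rather than capturing a non-const global is a choice made here to rule out global-variable access as a separate source of instability; it is not claimed to fix the GPU issue.

```julia
using Zygote, CUDA
using InteractiveUtils   # provides @code_warntype
using Cthulhu            # provides @descend (interactive)

term = CUDA.rand(Float32, 100)

# Hypothetical helper: same expression as in the question, with `term`
# passed explicitly instead of being read from a non-const global.
f(x, term) = sum(ifelse.(term .> 0, term .+ x, term))

# Hypothetical helper: gradient with respect to `x` only.
g(x, term) = gradient(x -> f(x, term), x)

# Print the inferred types; non-concrete types are highlighted.
@code_warntype g(5f0, term)

# Interactively step into the calls to locate where inference
# first loses a concrete type (run this in the REPL).
@descend g(5f0, term)
```

Running the same two macros with `term = rand(Float32, 100)` gives a CPU baseline to compare against, which may help show whether the instability first appears in the broadcast path.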


Last updated: Dec 28 2024 at 04:38 UTC