Stream: helpdesk (published)

Topic: gradient of divergence in Flux


Vasily Ilin (Nov 08 2022 at 00:44):

I am trying to take the gradient of the divergence of a function $$s: \mathbb{R}^d \to \mathbb{R}^d$$, but I get an error. Can anyone help me?
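Written out, with $$\theta$$ (notation not in the original post) standing for the parameters of the network `s`, the divergence is the trace of the Jacobian, which is what both definitions of `∇s` in the code compute. The quantity wanted is its gradient with respect to $$\theta$$, not $$x$$:

```math
\nabla \cdot s_\theta(x) = \sum_{i=1}^{d} \frac{\partial (s_\theta)_i}{\partial x_i}(x) = \operatorname{tr}\big(J_{s_\theta}(x)\big),
\qquad
\text{wanted: } \nabla_\theta \big(\nabla \cdot s_\theta(x)\big).
```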

using Flux, Zygote, LinearAlgebra

d = 5

s = Chain(
    Dense(d => 10, σ),
    Dense(10 => d))

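# Divergence of s (the trace of its Jacobian). These are two alternative definitions;
# the second one overwrites the first.
∇s(x) = sum(diag(jacobian(s, x)[1]))                        # way 1: sum of the Jacobian's diagonal
∇s(x) = sum(gradient(x -> s(x)[i], x)[1][i] for i in 1:d)   # way 2: sum of ∂s_i/∂x_i, one gradient per output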

x = rand(Float32, d)  # Float32 to match the element type of Flux's default parameters

loss_grad = gradient(() -> ∇s(x), Flux.params(s))  # this line throws the error
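One way to sidestep nested differentiation for this particular model is to write the divergence in closed form, so that only a single reverse-mode pass over the parameters is needed. The sketch below is a swapped-in technique, not a general fix: it assumes exactly the two-layer `Chain` above (sigmoid on the first layer, identity on the second) and a Flux version with the `Dense(in => out, σ)` syntax and explicit-model gradients (0.13 or later). `div_s` is a hypothetical helper name, and `Flux.gradient(m -> ..., s)` is used in place of the implicit `Flux.params` style.

```julia
using Flux, Zygote, LinearAlgebra

d = 5
s = Chain(Dense(d => 10, σ), Dense(10 => d))

# Hypothetical helper: closed-form divergence of s(x) = W2 * σ.(W1 * x .+ b1) .+ b2.
# The Jacobian is W2 * Diagonal(σ'(z)) * W1 with z = W1 * x .+ b1, hence
# tr(W2 * Diagonal(σ'(z)) * W1) = Σ_k σ'(z[k]) * (W1 * W2)[k, k].
# Assumes exactly this architecture: two Dense layers, sigmoid on the first only.
function div_s(m, x)
    W1, b1 = m.layers[1].weight, m.layers[1].bias
    W2 = m.layers[2].weight
    a = σ.(W1 * x .+ b1)
    dσ = a .* (1 .- a)                                     # σ'(z) for the logistic sigmoid
    diagW1W2 = vec(sum(W1 .* transpose(W2); dims = 2))     # (W1 * W2)[k, k] for each k
    return sum(dσ .* diagW1W2)
end

x = rand(Float32, d)

# Sanity check: matches the trace of the AD Jacobian (no nested AD involved).
@assert isapprox(div_s(s, x), tr(Zygote.jacobian(s, x)[1]); rtol = 1f-3, atol = 1f-4)

# div_s contains no inner gradient/jacobian call, so one reverse pass over the
# parameters suffices. Explicit style: the result mirrors the structure of `s`.
grads = Flux.gradient(m -> div_s(m, x), s)[1]
@show size(grads.layers[1].weight)   # (10, 5), same shape as W1
```

The gradients for the other parameters are at `grads.layers[1].bias` and `grads.layers[2].weight`. For a different architecture the trace would have to be derived again; the closed form does not carry over by itself.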

Last updated: Nov 22 2024 at 04:41 UTC