I am trying to take the gradient, with respect to the model parameters, of the divergence of a function s (i.e. ∇·s(x) = Σᵢ ∂sᵢ/∂xᵢ), but I am getting an error. Can anyone help me?
using Flux, Zygote, LinearAlgebra
d = 5
s = Chain(
    Dense(d => 10, σ),
    Dense(10 => d))
∇s(x) = sum(diag(jacobian(x -> s(x), x)[1])) # one way of getting the divergence: trace of the Jacobian
∇s(x) = sum(gradient(x -> s(x)[i], x)[1][i] for i in 1:d) # another way: sum of ∂sᵢ/∂xᵢ (note this redefines ∇s, so the call below uses this version)
x = rand(d)
loss_grad = gradient(() -> ∇s(x), Flux.params(s)) # this call errors
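For reference, here is a finite-difference version of the divergence I put together as a rough sanity check (the names div_fd and basis and the step size h are just my own choices). It avoids nesting AD calls, so the outer gradient with respect to Flux.params(s) only has to differentiate ordinary forward passes of s:
basis = [Float64.(1:d .== i) for i in 1:d] # standard basis vectors e₁,…,e_d, built without mutation
div_fd(x; h = 1e-5) = sum((s(x .+ h .* basis[i])[i] - s(x .- h .* basis[i])[i]) / (2h) for i in 1:d) # central differences ≈ Σᵢ ∂sᵢ/∂xᵢ
fd_grad = gradient(() -> div_fd(x), Flux.params(s)) # no nested AD here, so this should run
This only gives approximate values, though, so I would still like to know how to get the exact gradient of ∇s as written above.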