Is this a bug?
julia> using Zygote
julia> f(x, y) = y
f (generic function with 1 method)
julia> gradient(f, 0, 0)
(nothing, 1.0)
Shouldn't the result be (0.0, 1.0) instead?
Reported here: https://github.com/FluxML/Zygote.jl/issues/1538
Zygote uses nothing as a "hard" zero, i.e. a differential that's known at compile time to be zero is represented as nothing.
That is somewhat unexpected, mathematically speaking, and it makes it difficult to write generic code. Is there a good practice to handle nothing in this context?
I guess you could do something like
denothing(x) = x
denothing(::Nothing) = false   # false acts as a zero in arithmetic, e.g. false + 1.0 == 1.0
my_gradient(args...; kwargs...) = denothing.(gradient(args...; kwargs...))
julia> sum(my_gradient(f, 0, 0))
1.0
I wonder why this is not done automatically for end-users
Shouldn't it be fixed in Zygote.jl?
One reason for a special flag is that Zygote can avoid some work in the backward pass, as the gradient of any operations done before f is certain to be zero. Whereas with a runtime 0.0 it can't tell & must do the work.
The other is that for larger things like x::Array, allocating zero(x) is expensive.
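A small illustration of that second point (my own toy example, not from the message above): if an argument is a large array the function never touches, Zygote just returns nothing for it instead of materializing a zero array.
julia> using Zygote

julia> g(x, y) = sum(y);   # x is never used

julia> gradient(g, rand(10_000), [1.0, 2.0, 3.0])[1] === nothing   # no zero(x) of length 10_000 is built
true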
FWIW, Enzyme's design is:
julia> Enzyme.gradient(Reverse, (x,y) -> sum(abs2, x .* y), [1, 2.], [3 4 5.])
([100.0, 200.0], [30.0 40.0 50.0])
julia> Enzyme.gradient(Reverse, (x,y) -> sum(abs2, x .* x), [1, 2.], [3 4 5.])
([4.0, 32.0], [0.0 0.0 0.0])
julia> Enzyme.gradient(Reverse, (x,y) -> sum(abs2, x .* y), [1, 2.], Const([3 4 5.]))
([100.0, 200.0], nothing)
It could have had a design like ChainRulesCore.ZeroTangent(), since that at least supports math ops, but yeah, it's mostly just historical reasons and a ton of work to overhaul it
FWIW, Enzyme's design is:
Enzyme's gradient is pretty unlike Zygote's gradient.
I'd say the equivalent in Enzyme is instead
julia> let x = Ref(0.0), y = Ref(0.0)
           dx, dy = make_zero(x), make_zero(y)
           autodiff(Reverse, Duplicated(x, dx), Duplicated(y, dy)) do x, y
               f(x[], y[])
           end
           dx[], dy[]
       end
(0.0, 1.0)
I'm considering moving to Enzyme.jl because of this design of Zygote.jl. It is pretty counterintuitive to have a mathematical gradient with nothing entries.
Does Enzyme.jl support all platforms that Julia supports?
I understand it is a wrapper package
And another question: is Zygote.jl the recommended package for autodiff in native Julia, or is there something new?
https://discourse.julialang.org/t/state-of-ad-in-2024/112601
Can you say what problem nothing causes, more narrowly than just being surprising?
I think he wants to be able to do math with the result of gradient.
For instance, I agree that the fact that x + dx won't always work is a bit sad. (I think + needs to be replaced with Zygote.accum, which knows about nothing.) ChainRules.jl took making this work as an axiom, and the result was massive complexity of Tangent, which has all kinds of sharp edges. (Not to mention several kinds of zeros which nobody knows how to use correctly, and resulting type-instabilities.) So there are trade-offs, and nothing (plus NamedTuple for any struct) has the advantage of being very simple.
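To make the accum point concrete, a small sketch of my own:
using Zygote

f(x, y) = y
x, y = 0.0, 0.0
dx, dy = gradient(f, x, y)    # (nothing, 1.0)

# x + dx                      # MethodError: + is not defined for Float64 and Nothing
Zygote.accum(x, dx)           # 0.0 -- accum treats nothing as a zero
Zygote.accum(y, dy)           # 1.0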
Michael Abbott said:
Can you say what problem nothing causes, more narrowly than just being surprising?
We are simply doing Newton-Raphson iteration with automatic gradients. The problem with this nothing design is that it relies on all third-party packages handling it. Even if we work around the situation in our own package, this solution doesn't compose well.
Wouldn't Enzyme be a much better fit for stuff like Newton-Raphson because it supports mutation?
We are doing Newton-Raphson with 2 scalar values. There are no allocations.
Ah. In that case, maybe just use ForwardDiff?
Will take a look. I am assuming that ForwardDiff.jl provides autodiff like Zygote.jl but without the nothing.
You really only want to reach for reverse mode AD like Zygote if you need the derivatives of functions from N dimensions to M dimensions where N >> M
And in terms of maturity, ForwardDiff.jl is mature, actively maintained, etc?
The classic use-case for reverse-mode is deep learning where N might be in the many thousands and M = 1
ForwardDiff is very mature.
I'd say it's actively maintained, but I wouldn't say it's actively developed (on account of said maturity)
I already like that it has a much smaller list of dependencies compared to Zygote.jl
forward mode AD is just fundamentally much much much simpler than reverse mode
If you feel like trying out something bleeding edge instead, Diffractor.jl actually has a pretty well-working forward mode nowadays (probably don't actually do this)
As a general rule, you should avoid reverse mode like the plague unless you are absolutely sure you need it.
Thank you. That is very helpful.
Also, since it hasn't been mentioned yet, I highly recommend using DifferentiationInterface.jl which makes it trivial to swap out AD back-ends and has no performance penalty in simple cases.
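For illustration, a minimal sketch of what that swap looks like (my own example, not from DI's docs; DI differentiates single-argument functions, so the two scalars are packed into a vector):
using DifferentiationInterface
import ForwardDiff, Zygote

g(v) = v[2]   # the f(x, y) = y example from above, packed as v = [x, y]

gradient(g, AutoForwardDiff(), [0.0, 0.0])   # values [0.0, 1.0]
gradient(g, AutoZygote(), [0.0, 0.0])        # same call, different backend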
It'd be nice if DI turned the nothings into some sort of zero <:Number.
I think in this case we will go ahead with the ForwardDiff.jl package directly. There are no plans to swap the backend given that it is ideal for the application at hand.
For some reason ForwardDiff.jl is generating slower code compared to Zygote.jl.
Can you try to reproduce this benchmark on the main branch (Zygote) and on the forwarddiff branch?
https://github.com/JuliaEarth/CoordRefSystems.jl/tree/main/benchmark
Do you also see a massive slowdown in the last line of the output.csv? The last column has the speedup metric.
For me the Zygote.jl result is 0.28 and the ForwardDiff.jl result is 0.06 (larger is better).
We are simply doing Newton-Raphson iteration with automatic gradients.
If it's scalar, then you don't want to be diffing through it anyways. BracketingNonlinearSolve or SimpleNonlinearSolve with Zygote/ForwardDiff overloads would just skip the implicit part.
But I would almost guarantee for scalar that ForwardDiff will be faster here.
With forward mode you want to essentially always do this trick: https://github.com/SciML/NonlinearSolve.jl/blob/master/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Thank you @Christopher Rackauckas. Can you please elaborate on that?
The PR that replaces Zygote.jl by ForwardDiff.jl has a small diff that you can read here: https://github.com/JuliaEarth/CoordRefSystems.jl/pull/199/files
What do we need to do differently to get the expected superior performance of forward diff?
How are you solving the nonlinear system?
The diff has the formulas. Basically, given two functions fx and fy and target values x and y, we perform Newton iteration to find the two unknown coordinates.
These two formulas are decoupled in the diff above, as you can see.
Yeah so if it's using SimpleNonlinearSolve it should automatically apply the implicit rule
If you did it by hand then you'll need to copy that code / do a similar implicit function push through on the duals
For scalar it's almost equivalent to not differentiating the first n steps of the newton method, re-applying the duals, and then applying it on the n+1th step
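A hedged sketch of that trick (my own toy version, not the NonlinearSolve implementation linked above), for a scalar problem g(u, p) = 0:
using ForwardDiff

# plain Newton iteration on ordinary numbers
function newton_solve(g, u, p; iters=25)
    for _ in 1:iters
        u -= g(u, p) / ForwardDiff.derivative(v -> g(v, p), u)
    end
    return u
end

# when p carries Dual numbers: solve on the plain values first, then take a
# single Newton step with the Dual p, so derivatives only flow through the
# converged point (the "differentiate only the n+1th step" idea)
function newton_solve(g, u, p::ForwardDiff.Dual; iters=25)
    pv = ForwardDiff.value(p)
    ustar = newton_solve(g, u, pv; iters)
    return ustar - g(ustar, p) / ForwardDiff.derivative(v -> g(v, pv), ustar)
end

# d/dp of the root of u^2 - p at p = 4 is 1/(2*sqrt(p)) = 0.25
ForwardDiff.derivative(p -> newton_solve((u, q) -> u^2 - q, 1.0, p), 4.0)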
I am not sure I am following. As an end-user of ForwardDiff.jl it is not clear what I am doing wrong.
Optimally handling implicit equations is not something automatic differentiation as a tool can do on its own. It requires that the solver library that you're using for the implicit system overloads the AD to avoid differentiation through the method
So Zygote.jl is doing something more that guarantees better performance?
The derivative of Newton-Raphson w.r.t. u0 is 0, since the solution is independent of the initial condition (or undefined if it moves to a different solution). So you need to not differentiate the solve and then only differentiate effectively the last step. If the implicit solve is the expensive part of the code, then doing this trick turns O(n) expensive calls differentiating each step into exactly 1. That's hard to beat.
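Putting my own paraphrase of that as a formula (this is just the implicit function theorem): if g(u*(p), p) = 0 at the converged root, then
du*/dp = -(∂g/∂u)^(-1) (∂g/∂p), evaluated at (u*, p),
so one extra linear solve at the solution gives the derivative, no matter how many iterations the solver took.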
It's not really up to the AD libraries. It's up to the solver libraries, i.e. whomever writes the Newton method (NonlinearSolve) to supply rules for ForwardDiff/Zygote/etc. to do this
You mean that there is a small package that we could take as a dependency that already defines Newton-Raphson inversion with AD rules?
Since an AD library cannot really know by looking at the code that it has this convergence property, i.e. that the solution is independent of the previous steps: not in the code (since each Newton step depends on the previous one), but in the solution (since it converges to the same value regardless of where you start)
It's a split out of NonlinearSolve that is focused only on very simple Newton-Raphson + the required AD rules.
It looks like the list of dependencies is very large?
Ideally, we would just retain the performance of Zygote.jl, but with ForwardDiff.jl
Well with a scalar nonlinear solve you probably want to be using ITP instead of Newton for stability if you have bounds. In that case, BracketingNonlinearSolve would then be an even smaller dep.
What exactly do you need different in the size? The import time is ~200ms and most of that is the precompilation load of the Newton method itself.
It beggars belief that the code in the diff you pasted is much slower in forwarddiff than in zygote, though of course I don't know what functions you are running through it. I think there is something else wrong.
Most likely yes, Zygote shouldn't ever be faster in this kind of case.
But even then, the next thing you'd want to do is do the implicit rule for either ForwardDiff or Zygote :shrug:
Perhaps I didn't run the benchmark properly. Let me try to isolate the issue.
using SimpleNonlinearSolve
f(u, p::Number) = u * u - p
f(u, p::Vector) = u * u - p[1]
u0 = 1.0
p = 1.0
const cprob = NonlinearProblem(f, u0, p)
sol = solve(cprob, SimpleNewtonRaphson())
function loss(p)
    solve(remake(cprob, p=p), SimpleNewtonRaphson()).u - 4.0
end
using ForwardDiff, BenchmarkTools
@btime ForwardDiff.derivative(loss, p)
16.741 ns (1 allocation: 16 bytes)
For a scalar problem you should be able to optimize most stuff out of it.
Though a bracketing method is almost certainly going to be more robust
using BracketingNonlinearSolve
f(u, p::Number) = u * u - p
u0 = 1.0
p = 1.0
uspan = (1.0, 2.0) # brackets
const cprob_int = IntervalNonlinearProblem(f, uspan, p)
sol = solve(cprob_int)
function loss(p)
    solve(remake(cprob_int, p=p)).u - 4.0
end
using ForwardDiff, BenchmarkTools
@btime ForwardDiff.derivative(loss, p);
18.495 ns (1 allocation: 16 bytes)
You can probably specialize on a lot of other properties too though. What kind of system is it? Is it polynomial? Rational polynomial?
I am creating a MWE with the exact code that is slower. Will share here in a few minutes...
using CoordRefSystems
using BenchmarkTools
latlon = LatLon(45, 90)
winkel = convert(WinkelTripel, latlon)
@btime convert($LatLon, $winkel)
1.491 μs (10 allocations: 192 bytes) # main
6.356 μs (144 allocations: 2.88 KiB) # PR
You can see that the ForwardDiff.jl in the PR is 6x slower. The underlying functions fx and fy are here:
Trigonometric functions.
what is sincα?
oh I see
defined right above
Wait you're talking about AD in the nonlinear solve not of the nonlinear solve?
Yes, the AD is in the functions fx and fy inside the nonlinear solve.
I was assuming that this should be instantaneous given the "simplicity" of these trigonometric functions.
so where is your forwarddiff code?
These are functions
My guess is you did something odd to handle the multiple returns.
I think this kind of scalar, branch-free straight line code is the best-case performance scenario for Zygote. So it's not crazy that it'd be faster than ForwardDiff.
In this PR I shared a few messages ago: https://github.com/JuliaEarth/CoordRefSystems.jl/pull/199
The PR literally replaces Zygote by ForwardDiff, nothing else.
yeah this kind of case is not so bad for Zygote, though either should do fine
you shouldn't be getting so many allocs with forwarddiff though
but for this kind of case, AD inside the ODE for a scalar output, Zygote should just optimize out all allocs which is usually what would kill it
So Zygote should be fine, and should almost even match Enzyme here without some Reactant tricks.
The only other thing to try really is just avoiding the AD with something like an ITP and seeing how that does.
So the moral of the story is that Zygote.jl is still recommended even in this scalar case with N=2 and M=1?
in this case, yes, because it can compile away a bunch of stuff so its normal issues don't come up here.
there are cases for which that is not true
it's somewhat code dependent
Zygote sucks at optimizing code with arrays and falls off a cliff any time there's a branch, but yes this is one of the few niches it's perf-competitive in.
Hence the demos Mike and others used to do where they showed it constant-folding all the way to the correct gradient
These heuristics to pick an AD backend are super hard. Every time we dive into it, we unlearn something we had been told.
The original issue of this thread, where Zygote.jl returns nothing, still bothers me. That is really annoying.
I think it's better to have a hard zero? It's annoying when AD just treats structural zeros as 0.0, because then it's harder to debug.
For your case you could just x === nothing ? 0.0 : x
A lot of inputs Zygote accepts are not conducive to having natural zeros. Structs with arbitrary type constraints, for example
though almost certainly if you get that nothing in your code, it's likely a bug and you should throw an error saying "you likely have a bug in your f"
One challenge ChainRules and later Mooncake ran into is that some types can't even be reliably represented by structural zeroes! Self-referential structs being a big culprit
It is not a bug in f. It is common to have formulas that only depend on a subset of the arguments in this context.
yeah but that's a general case
That is not grounded in this specific case
in this specific case, if you get nothing, that means f is not a function of the parameter
that means you can just remove it from the rootfind
that tells you that you can optimize it more!
I think in an alternate world where ChainRules matured a little earlier, Zygote could've used ZeroTangent and NoTangent instead of nothing
That is a good point. Maybe refactoring the algorithm with a branch that handles nothing is not that bad. In any case, I wish we had Enzyme.jl's behavior here; it always returns 0.0 for a zero gradient.
Regardless though, this code should want the nothing or whatever structural zero, because then it should just branch down to doing a scalar rootfind and double its speed
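For instance, a hypothetical sketch (my own, not the actual CoordRefSystems code) of what that branch could look like for the 2-variable Newton step:
using Zygote

function newton_step(fx, fy, x, y)
    ∂x = gradient(fx, x, y)    # ∂x[2] === nothing when fx never uses y
    ∂y = gradient(fy, x, y)
    if ∂x[2] === nothing && ∂y[1] === nothing
        # Jacobian is diagonal: two independent scalar Newton updates
        return x - fx(x, y) / ∂x[1], y - fy(x, y) / ∂y[2]
    else
        # full 2x2 Newton update, treating nothing as a structural zero
        z(v) = v === nothing ? 0.0 : v
        J = [z(∂x[1]) z(∂x[2]); z(∂y[1]) z(∂y[2])]
        Δ = J \ [fx(x, y), fy(x, y)]
        return x - Δ[1], y - Δ[2]
    end
end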
This code should also be compatible with Enzyme?
It is. It is just that we are trying to keep it native Julia as much as possible, at least for now. Maybe we will consider Enzyme.jl as the only exception.
Exploiting the structural zero with Zygote would still beat Enzyme here though
The full stack is native Julia, which facilitates deployment in exotic platforms.
(screenshot attached)
Chopping out the fy gradient could be like half of the compute, so I'd just exploit the nothing and call it a day.
The full stack is native Julia, which facilitates deployment in exotic platforms.
Like what?
Christopher Rackauckas said:
Chopping out the fy gradient could be like half of the compute, so I'd just exploit the nothing and call it a day.
Yes, it sounds reasonable.
The other thing you could potentially do is use fastmath approximations to the trig functions in the gradient context.
Or run this as a mixed precision and just do the gradient in 32-bit
Christopher Rackauckas said:
The full stack is native Julia, which facilitates deployment in exotic platforms.
Like what?
We are investigating some heterogeneous cluster setups. I understand that external binary dependencies may support a subset of the platforms that Julia supports.
So we avoid external binary deps as much as possible. What is the situation with Enzyme.jl? Does it support all platforms that Julia does because it is LLVM-based?
like how exotic though, ARMv7/8? Or like, embedded type chips?
Julia doesn't even support all LLVM supported platforms because of runtime things
Christopher Rackauckas said:
like how exotic though, ARMv7/8? Or like, embedded type chips?
Nothing specific at the moment. We are just trying to save ourselves from build issues that we can't address easily.
Christopher Rackauckas said:
Julia doesn't even support all LLVM supported platforms because of runtime things
So adding Enzyme.jl as a dependency shouldn't reduce the list of supported platforms, right?
Premature optimization can be the root of all evil.
I mean, it might be easier to get Julia to kick something out for like a TI C600 without Enzyme, but the chances that will ever be in a cluster is zero.
In this case, I see it as precaution. If we can stick to a native Julia app, why not? :smile:
If Enzyme.jl is indeed the best thing to adopt, and the benefits outweigh the downsides, we will go for it.
I mean, I see eVTOLs and satellites deploying to ARMv8 these days. I would be surprised if your case is actually all that exotic unless it's for a microsat