I have a deeply recursive structure for which I want to be able to compute gradients with Zygote. The structure is a computation graph where each vertex has references to its inputs.
The problem is that Zygote randomly stalls indefinitely when compiling the backward pass. Until now I have worked around this by first flattening the graph into a topologically sorted vector inside a @nograd function and then doing the evaluation in a loop. However, I just realized that this approach doesn't work at all for explicit gradients, which, as far as I understand, will soon be the mainstream (only?) way to update parameters.
Old discourse thread with some example code in case it is unclear what I'm talking about: https://discourse.julialang.org/t/improve-performance-of-computation-graph-evaluation/32873
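A minimal plain-Julia sketch of the flatten-then-loop workaround, next to the naive recursive evaluation. The vertex types (InputVertex, CompVertex) and the helpers toposort, evaluate_rec and evaluate_flat are hypothetical and not taken from the linked thread; in the Zygote setting, toposort is the step that would be hidden from AD with @nograd (or ChainRulesCore.@non_differentiable).

```julia
# Hypothetical minimal graph types (not the poster's actual implementation).
abstract type AbstractVertex end

struct InputVertex <: AbstractVertex
    idx::Int  # position of this input in the argument vector
end

struct CompVertex <: AbstractVertex
    op::Function                    # operation applied to the input values
    inputs::Vector{AbstractVertex}  # references to the input vertices
end

inputs(::InputVertex) = AbstractVertex[]
inputs(v::CompVertex) = v.inputs

# Recursive evaluation: every vertex evaluates its inputs first.
# Shared subgraphs are recomputed since nothing is cached.
evaluate_rec(v::InputVertex, xs) = xs[v.idx]
evaluate_rec(v::CompVertex, xs) = v.op(map(vi -> evaluate_rec(vi, xs), v.inputs)...)

# Pre-step: flatten the graph into a topologically sorted vector (inputs first).
# This is the part that would be hidden from AD.
function toposort(output::AbstractVertex)
    order = AbstractVertex[]
    seen = IdDict{AbstractVertex,Nothing}()
    function visit(v)
        haskey(seen, v) && return
        seen[v] = nothing
        foreach(visit, inputs(v))
        push!(order, v)
    end
    visit(output)
    return order
end

# Loop-based evaluation over the flattened graph, caching each vertex's value.
function evaluate_flat(order, xs)
    vals = IdDict{AbstractVertex,Any}()
    for v in order
        if v isa InputVertex
            vals[v] = xs[v.idx]
        else
            vals[v] = v.op(map(vi -> vals[vi], inputs(v))...)
        end
    end
    return vals[last(order)]
end
```

For example, with `out = CompVertex(*, AbstractVertex[CompVertex(+, AbstractVertex[x1, x2]), x2])`, both evaluators compute (x1 + x2) * x2, but only the flat version visits each vertex exactly once.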
Question 1: Does anyone have an intuition about why Zygote would stall on a recursive structure? I suppose I should open an issue, but I have a hard time coming up with an MWE. The structure purposely uses some abstract and/or untyped fields to allow for mutation. Could that be the problem? Or could it be the opposite, e.g. too much specialization (which afaik can also have a negative impact on compile times), since many fields are typed and the compositional hierarchy is quite deep?
Question 2: Is it madness to try to stitch together gradients in the pullback function of an rrule so that I get a Tangent vector for the flattened graph, and then recurse through the graph to create a Tangent structure which matches the graph? This basically moves all recursion to a pre-step in the primal (flatten the graph) and a post-step in the pullback (turn the flat Tangent array into a recursive Tangent structure which matches the structure of the graph), so that AD never sees any recursion. I have a (seemingly) working version of this which didn't turn out too horrible, but maybe it will be a nightmare to maintain.
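A rough sketch of the stitching idea, assuming the graph has already been flattened into a topologically sorted vector `order`. Everything here is hypothetical: ScaleVertex computes w * sum(inputs), the hand-written backward plays the role of the pullback, and plain NamedTuples stand in for ChainRulesCore.Tangent (in an actual rrule one would presumably build `Tangent{ScaleVertex}(; w = ..., inputs = ...)` instead).

```julia
# Hypothetical graph types: each ScaleVertex computes w * sum(inputs).
abstract type AbstractVertex end

struct InputVertex <: AbstractVertex
    idx::Int
end

struct ScaleVertex <: AbstractVertex
    w::Float64
    inputs::Vector{AbstractVertex}
end

# Forward pass over an already flattened, topologically sorted `order`.
function forward(order, xs)
    vals = IdDict{AbstractVertex,Float64}()
    for v in order
        vals[v] = v isa InputVertex ? xs[v.idx] : v.w * sum(vi -> vals[vi], v.inputs)
    end
    return vals
end

# Hand-written reverse pass over the same flat order: no recursion, and the
# per-vertex parameter tangents come out in a flat IdDict.
function backward(order, vals, dout)
    dvals = IdDict{AbstractVertex,Float64}()
    for v in order
        dvals[v] = 0.0
    end
    dvals[last(order)] = dout
    flat = IdDict{AbstractVertex,Any}()
    for v in reverse(order)
        v isa ScaleVertex || continue
        s = sum(vi -> vals[vi], v.inputs)
        flat[v] = (w = dvals[v] * s,)
        for vi in v.inputs
            dvals[vi] += dvals[v] * v.w  # accumulate into each input
        end
    end
    return flat, dvals
end

# Post-step: recurse once over the graph to turn the flat tangents into a
# nested structure mirroring the vertex structs.
nest_tangent(::InputVertex, flat) = nothing
nest_tangent(v::ScaleVertex, flat) =
    (w = flat[v].w, inputs = [nest_tangent(vi, flat) for vi in v.inputs])
```

The only recursion left is in nest_tangent, which runs after the numerical work is done, so whatever differentiates the forward pass only ever sees the loop.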
I guess a third option is to do some getproperty hacking so that the flat tangent vector is backed by the graph (e.g. getproperty(graph, :vertices) returns the flattened graph), but I don't know how to get ChainRulesCore to accept this.
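The primal side of the getproperty idea could look roughly like this. Node, Graph, flatten_chain and evaluate are hypothetical names for a single-input chain; the sketch deliberately leaves out the open question of how ChainRulesCore or Zygote should map a tangent for the virtual :vertices property back onto Graph.

```julia
# Hypothetical recursive chain: each node references its single input node.
struct Node
    f::Function
    input::Union{Node,Nothing}
end

struct Graph
    output::Node
end

# Flatten the chain (input first) with a loop instead of recursion.
function flatten_chain(n::Node)
    nodes = Node[]
    cur = n
    while cur !== nothing
        pushfirst!(nodes, cur)
        cur = cur.input
    end
    return nodes
end

# Expose the flattened view as a "virtual" property: graph.vertices.
function Base.getproperty(g::Graph, s::Symbol)
    s === :vertices && return flatten_chain(getfield(g, :output))
    return getfield(g, s)
end
Base.propertynames(::Graph, private::Bool = false) = (:output, :vertices)

# Loop-based evaluation over the flat view.
evaluate(g::Graph, x) = foldl((acc, n) -> n.f(acc), g.vertices; init = x)
```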
It seems like recursion vs looping might be a red herring. I managed to get a model which reproduced the issue 100% for both looping and recursion and for that case the culprit seems to be related to https://github.com/FluxML/Zygote.jl/issues/1111.
When I either strip the graph of any mutable wrappers or define rrules for them manually, I get both implicit and explicit gradients for both the looping+stitching variant and the recursive variant. Let's see if it holds up when running more models...
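A sketch of the "strip the mutable wrappers" route, assuming a hypothetical MutableVertex that is only mutated when the graph is edited. freeze builds an immutable FrozenVertex mirror that shares the parameter arrays, so the pass being differentiated never reads fields of a mutable struct; whether this is enough for a given model is what remains to be tested.

```julia
# Hypothetical mutable vertex: mutated only when the graph itself is edited.
mutable struct MutableVertex
    W::Matrix{Float64}
    inputs::Vector{MutableVertex}
end

# Immutable mirror used only for the differentiated forward pass.
struct FrozenVertex
    W::Matrix{Float64}
    inputs::Vector{FrozenVertex}
end

# Strip the mutable wrappers recursively. The parameter arrays are shared,
# not copied.
freeze(v::MutableVertex) = FrozenVertex(v.W, FrozenVertex[freeze(vi) for vi in v.inputs])

# Forward pass on the frozen graph: leaves apply W to x, other vertices
# apply W to the sum of their inputs' outputs.
(v::FrozenVertex)(x) = isempty(v.inputs) ? v.W * x : v.W * sum(vi -> vi(x), v.inputs)
```

The conversion happens outside the differentiated code, e.g. `fg = freeze(graph)` first and then take the gradient of a loss that calls `fg(x)`.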
DrChainsaw said:
The structure purposely uses some abstract and/or untyped fields to allow for mutation. Could that be the problem? Or could it be the opposite, e.g. too much specialization (which afaik can also have a negative impact on compile times), since many fields are typed and the compositional hierarchy is quite deep?
I think both of these have come up before. Have you tried tossing the gradient call through PProf/StatProfilerHTML/etc to see where most of the time is spent? If LLVM is showing up a lot, then https://timholy.github.io/SnoopCompile.jl/stable/reference/#SnoopCompileCore.@snoopl may be helpful too.
Mutable structs in Zygote are a massive headache :(. If I had a choice I would remove setfield! support entirely and wipe out a whole class of bugs + a good amount of internal complexity.
@Brian Chen: SnoopCompile works with gradients!? I had no idea!
I'll definitely try it if I run into more issues. Is there a way to make it produce output in the complete stalling case or does it need to complete before any output is generated? When it stalls, not even ctrl+C works. I need to kill the whole process.
SnoopCompile works with pretty much anything! Though note here I've not linked the type inference part of that library.
One thing I've been trying recently is:
@snoopi_deep
@snoopl
Yeah, I understand the issues with mutability and Zygote. In my case the mutable structs are only mutated when making changes to the graph (i.e. changing sizes of parameter arrays or adding/removing vertices or edges), and thankfully I don't need to compute gradients for that. I think I can work around all the cases now that I know this was the culprit.
Unfortunately Zygote will destroy type stability the moment it sees a getfield on a mutable type. Likewise if it sees control flow. I've been thinking about ways to mitigate the damage, but working on the internals is slow going and there are a lot of edge cases to cover (e.g. nested differentiation).
That's good info. I don't think I'm hurt a lot by the type instability, but I guess it doesn't hurt to get rid of it. In many cases I can probably rewrite to a normal struct with some fields being RefValues (these are also the fields that are not touched in the forward pass). Would that be an improvement?
I'm not sure, since Refs get a similar treatment.
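For reference, the RefValue restructuring asked about above might look like the following. Vertex, nout and setnout! are hypothetical; the sketch only shows the shape of the change, and per the reply it is not clear that Zygote treats the Ref fields any better than a mutable struct.

```julia
# Hypothetical vertex: immutable struct; structural metadata lives in a Ref
# and is only changed when the graph is edited, never in the forward pass.
struct Vertex
    W::Matrix{Float64}           # parameters read in the forward pass
    nout::Base.RefValue{Int}     # mutable metadata (e.g. output size)
end

(v::Vertex)(x) = v.W * x

# Structural edit: mutate the Ref contents, not the struct itself.
function setnout!(v::Vertex, n::Integer)
    v.nout[] = n
    return v
end
```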
Last updated: Nov 06 2024 at 04:40 UTC