Stream: helpdesk (published)

Topic: Explicit gradients when Zygote chokes on recursion


DrChainsaw (Jun 10 2022 at 17:32):

I have a deeply recursive structure for which I want to be able to compute gradients with Zygote. The structure is a computation graph where each vertex has references to its inputs.

The problem is that Zygote sometimes stalls indefinitely when compiling the backward pass. Until now I have worked around this by first flattening the graph into a topologically sorted vector inside a @nograd function and then evaluating it in a loop. However, I just realized that this approach doesn't work at all for explicit gradients, which, as far as I understand, will soon be the mainstream (only?) way to update parameters: since the flat vector is created inside the @nograd, no gradient flows back to the graph passed to gradient.
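A minimal Base-Julia sketch of the two evaluation strategies mentioned above: recursive evaluation starting from the output, versus flattening the graph into a topologically sorted vector and looping over it. The types AbstractVertex, InputVertex and OpVertex and the function names are hypothetical and are not the code from the Discourse thread; Zygote is not loaded, so this only shows the primal computation.

```julia
# Hypothetical types: every vertex keeps references to its inputs.
abstract type AbstractVertex end

struct InputVertex <: AbstractVertex
    index::Int                        # position of this graph input in `xs`
end

struct OpVertex{F} <: AbstractVertex
    op::F                             # function applied to the input values
    inputs::Vector{AbstractVertex}    # abstractly typed field
end

# Recursive evaluation: start at the output and recurse into the inputs.
# A vertex shared by several consumers is recomputed once per path.
evaluate_recursive(v::InputVertex, xs) = xs[v.index]
evaluate_recursive(v::OpVertex, xs) = v.op(map(i -> evaluate_recursive(i, xs), v.inputs)...)

# Flatten into a topologically sorted vector (inputs before their consumers)
# with a depth-first post-order traversal. In the workaround described above,
# this step was hidden from Zygote with @nograd.
function topological_sort(output::AbstractVertex)
    order = AbstractVertex[]
    seen = IdDict{AbstractVertex,Bool}()
    function visit(v)
        haskey(seen, v) && return
        seen[v] = true
        v isa OpVertex && foreach(visit, v.inputs)
        push!(order, v)
    end
    visit(output)
    return order
end

# Loop evaluation over the flat vector: each vertex is computed once, no recursion.
function evaluate_loop(output::AbstractVertex, xs)
    values = IdDict{AbstractVertex,Any}()
    for v in topological_sort(output)
        values[v] = v isa InputVertex ? xs[v.index] : v.op((values[i] for i in v.inputs)...)
    end
    return values[output]
end

x1, x2 = InputVertex(1), InputVertex(2)
s = OpVertex(+, AbstractVertex[x1, x2])   # x1 + x2
p = OpVertex(*, AbstractVertex[s, x1])    # (x1 + x2) * x1
evaluate_recursive(p, (2, 3)) == evaluate_loop(p, (2, 3)) == 10   # true
```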

An old Discourse thread with some example code, in case it is unclear what I'm talking about: https://discourse.julialang.org/t/improve-performance-of-computation-graph-evaluation/32873

Question 1: Does anyone have an intuition about why Zygote would stall on a recursive structure? I suppose I should open an issue, but I have a hard time coming up with an MWE. The structure purposely uses some abstract and/or untyped fields to allow for mutation. Could that be the problem? Or could it be the opposite, e.g. too much specialization (which afaik can also have a negative impact on compile times), since many fields are typed and the compositional hierarchy is quite deep?

Question 2: Is it madness to try to stitch together gradients in the pullback function of an rrule, so that I get a Tangent vector for the flattened graph, and then recurse through the graph to create a Tangent structure that matches the graph? This basically moves all recursion to a pre-step in the primal (flatten the graph) and a post-step in the pullback (turn the flat Tangent array into a recursive Tangent structure that matches the structure of the graph), so that AD never sees any recursion. I have a (seemingly) working version of this which didn't turn out too horrible, but maybe it will be a nightmare to maintain.
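The pre-step/post-step idea can be sketched in plain Julia as follows. Everything here is hypothetical: Vertex, flatten_graph and unflatten_grads are illustrative names, the per-vertex gradients are simply a given vector rather than the output of a real pullback, and NamedTuples stand in for ChainRulesCore.Tangent so the sketch runs without dependencies.

```julia
# Hypothetical vertex with one trainable scalar and references to its inputs.
struct Vertex
    weight::Float64
    inputs::Vector{Vertex}
end

# Pre-step in the primal: flatten into topological order (inputs before consumers).
function flatten_graph(output::Vertex)
    order = Vertex[]
    seen = IdDict{Vertex,Bool}()
    function visit(v)
        haskey(seen, v) && return
        seen[v] = true
        foreach(visit, v.inputs)
        push!(order, v)
    end
    visit(output)
    return order
end

# Post-step in the pullback: turn flat per-vertex gradients (aligned with `order`)
# into a nested structure mirroring the fields of each Vertex.
function unflatten_grads(output::Vertex, order::Vector{Vertex}, flatgrads::Vector{Float64})
    position = IdDict{Vertex,Int}()
    for (i, v) in enumerate(order)
        position[v] = i
    end
    emitted = IdDict{Vertex,Bool}()
    function build(v)
        # A vertex reached along several paths gets its gradient only the first
        # time; later occurrences get `nothing` (Zygote's "no gradient" marker).
        haskey(emitted, v) && return nothing
        emitted[v] = true
        return (weight = flatgrads[position[v]], inputs = map(build, v.inputs))
    end
    return build(output)
end

a = Vertex(1.0, Vertex[])
b = Vertex(2.0, Vertex[a])
c = Vertex(3.0, Vertex[a, b])                    # `a` is shared by `b` and `c`
order = flatten_graph(c)                         # [a, b, c]
t = unflatten_grads(c, order, [0.1, 0.2, 0.3])
```

The main design choice is how to treat a vertex reachable along several paths. This sketch puts its gradient at the first occurrence and nothing elsewhere, so that summing over all occurrences counts it exactly once; whether that matches what the downstream optimizer expects is an assumption worth checking.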

I guess a third option is to do some getproperty hacking so that the flat tangent vector is backed by the graph (e.g. getproperty(graph, :vertices) returns the flattened graph), but I don't know how to get ChainRulesCore to accept this.

DrChainsaw (Jun 10 2022 at 22:30):

It seems like recursion vs. looping might be a red herring. I managed to find a model that reproduces the issue 100% of the time with both looping and recursion, and for that case the culprit seems to be related to https://github.com/FluxML/Zygote.jl/issues/1111.

When I either strip the graph of all mutable wrappers or define rrules for them manually, I get both implicit and explicit gradients for both the looping+stitching variant and the recursive variant. Let's see if it holds up when running more models...

Brian Chen (Jun 10 2022 at 22:30):

DrChainsaw said:

The structure purposely uses some abstract and/or untyped fields to allow for mutation. Could that be the problem? Or could it be the opposite, e.g. too much specialization (which afaik also can have a negative impact on compile times) as many fields are typed and the compositional hierarchy is quite deep?

I think both of these have come up before. Have you tried running the gradient call through PProf/StatProfilerHTML/etc. to see where most of the time is spent? If LLVM shows up a lot, then https://timholy.github.io/SnoopCompile.jl/stable/reference/#SnoopCompileCore.@snoopl may be helpful too.

Brian Chen (Jun 10 2022 at 22:32):

Mutable structs in Zygote are a massive headache :(. If I had a choice, I would remove setfield! support entirely and wipe out a whole class of bugs along with a good amount of internal complexity.

DrChainsaw (Jun 10 2022 at 22:32):

@Brian Chen: SnoopCompile works with gradients!? I had no idea!

I'll definitely try it if I run into more issues. Is there a way to make it produce output when the call stalls completely, or does the call need to finish before any output is generated? When it stalls, not even Ctrl+C works; I have to kill the whole process.

Brian Chen (Jun 10 2022 at 22:33):

SnoopCompile works with pretty much anything! Note, though, that the link I posted isn't to the type inference part of that library.

Brian Chen (Jun 10 2022 at 22:35):

One thing I've been trying recently is:

  1. get end-to-end and type inference times with @snoopi_deep
  2. get LLVM time with @snoopl
  3. If inference time is high, that points to a type stability issue. If LLVM time is high, that points to a codegen and possibly an overspecialization issue. If neither is high, wow, we solved TTFG (time to first gradient)

DrChainsaw (Jun 10 2022 at 22:35):

Yeah, I understand the issues with mutability and Zygote. In my case the mutable structs are only mutated when making changes to the graph (i.e. changing the sizes of parameter arrays or adding/removing vertices or edges), and thankfully I don't need to compute gradients for that. I think I can work around all the cases now that I know this was the culprit.

Brian Chen (Jun 10 2022 at 22:36):

Unfortunately, Zygote will destroy type stability the moment it sees a getfield on a mutable type. The same goes for control flow. I've been thinking about ways to mitigate the damage, but working in the internals is slow going and there are a lot of edge cases to cover (e.g. nested differentiation).

DrChainsaw (Jun 10 2022 at 22:45):

That's good info. I don't think the type instability hurts me a lot, but I guess it doesn't hurt to get rid of it. In many cases I can probably rewrite the mutable structs as normal (immutable) structs with some fields being RefValues (these are also fields that are not touched in the forward pass). Would that be an improvement?
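A minimal Base-Julia sketch of the "normal struct with RefValue fields" pattern: the struct is immutable, the forward pass reads only op, and the field that graph-editing code changes sits behind a Base.RefValue. NamedOp, rename! and the name field are hypothetical. The sketch does not load Zygote, so it says nothing about how Zygote treats the Ref, which is the open question here.

```julia
# Hypothetical immutable vertex: `op` is read in the forward pass, while `name`
# sits behind a Base.RefValue so it can change without making the struct mutable.
struct NamedOp{F}
    op::F
    name::Base.RefValue{String}
end

NamedOp(op, name::String) = NamedOp(op, Ref(name))

# Forward pass: only `op` is used; `name` is never read here.
(v::NamedOp)(x) = v.op(x)

# Graph editing (no gradients needed): mutate through the Ref.
rename!(v::NamedOp, newname::String) = (v.name[] = newname; v)

v = NamedOp(abs2, "square")
v(3)                  # forward pass, returns 9
rename!(v, "sq")      # graph edit through the Ref; `v` itself stays immutable
```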

Brian Chen (Jun 11 2022 at 00:06):

I'm not sure, since Refs get a similar treatment.


Last updated: Nov 22 2024 at 04:41 UTC