Stream: helpdesk (published)

Topic: Equivalent of a `scan` function?


view this post on Zulip Brian Chen (May 28 2022 at 16:54):

Copied from Slack and hoping for some more expert opinions here :)

Do we have an equivalent to https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html? accumulate is the closest stdlib I know of, but it doesn't allow for carrying of non-output intermediate states. https://juliafolds.github.io/Transducers.jl/dev/reference/manual/#Transducers.ScanEmit was the best example I could find in the broader ecosystem, but I'm not sure if there's a way to extract uₙ from it.
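To make the limitation concrete, here is a minimal sketch using only Base. With `accumulate`, each carried state is also the emitted output. One workaround is to carry a `(state, output)` tuple and split it afterwards. The step function `emit_step` and the dummy `0` in `init` are assumptions made for illustration, and the sketch fails on empty input because it calls `last`.

```julia
# With accumulate, every carried state is also the emitted output.
running = accumulate(+, 1:3; init=0)

# Workaround sketch: carry a (state, output) tuple and split it afterwards.
# `emit_step` is a hypothetical step function returning (new_state, output).
emit_step(u, x) = (u - 1, x + 1)

# The second slot of `init` is a dummy output that is never read.
steps = accumulate((prev, x) -> emit_step(first(prev), x), 1:3; init=(-1, 0))

final_state = first(last(steps))   # the carried state after the last element
outputs = map(last, steps)         # the per-element outputs
```

This stores every intermediate state alongside the outputs, which is wasteful when the state is large (for example an RNN hidden state).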

view this post on Zulip Mason Protter (May 28 2022 at 20:15):

@Brian Chen yeah, the uₙ should be available, at least depending on what you're planning to do with them

view this post on Zulip Mason Protter (May 28 2022 at 20:19):

julia> 1:3 |> ScanEmit(-1) do u, x
           @show u
           (x+1, u-1)
       end |> collect
u = -1
u = -2
u = -3
3-element Vector{Int64}:
 2
 3
 4

view this post on Zulip Mason Protter (May 28 2022 at 20:22):

The jax readme was not very enlightening to me, maybe you could give a julian pseudoexample of what you want?

view this post on Zulip Brian Chen (May 28 2022 at 20:33):

It's available but not returned, unfortunately. For comparison, JAX's scan would return (-4, [2, 3, 4]).

view this post on Zulip Brian Chen (May 28 2022 at 20:34):

I could try a local var, but closure boxing makes that difficult. A Ref or other mutable struct is a no-go too, because of Zygote.
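A minimal sketch of the local-variable approach, to show where the boxing comes from. The name `scan_closure` is hypothetical. The result is correct, but `acc` is both captured by the closure and reassigned inside it, so Julia stores it in a `Core.Box` and the function is no longer type-stable.

```julia
# Hypothetical helper: scan written with a captured local variable.
function scan_closure(f, init, xs)
    acc = init                 # captured and reassigned below, so it gets boxed
    ys = map(xs) do x
        acc, y = f(acc, x)     # updates the outer `acc`; `y` is local to the closure
        y
    end
    return acc, ys
end

result = scan_closure((u, x) -> (u - 1, x + 1), -1, 1:3)
```

This also relies on `map` visiting the elements in order, which holds for ordinary vectors and ranges but is an assumption worth stating.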

view this post on Zulip Brian Chen (May 28 2022 at 20:38):

Mason Protter said:

maybe you could give a julian pseudoexample of what you want?

function scan(f, init, xs)
  acc = init
  ys = []
  for x in xs
    acc, y = f(acc, x)
    push!(ys, y)
  end
  return acc, ys
end

view this post on Zulip Brian Chen (May 28 2022 at 20:44):

This is the minimal version. JAX does a bunch of extra stuff with dim handling and SoA inputs, but all I want for Flux/Lux is new_rnn_state, output_seq = scan(rnn_forward, rnn_state, input_seq).
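A usage sketch for that call shape, under stated assumptions: `toy_cell` is a hypothetical stand-in for `rnn_forward` that works on plain integers, and the `identity.(ys)` at the end is an added tweak (not part of the version above) that narrows the `Vector{Any}` to a concrete element type where possible.

```julia
function scan(f, init, xs)
    acc = init
    ys = []
    for x in xs
        acc, y = f(acc, x)
        push!(ys, y)
    end
    # Broadcasting identity re-collects the outputs with a narrower eltype.
    return acc, identity.(ys)
end

# Hypothetical stand-in for an RNN cell: returns (new_state, output).
toy_cell(h, x) = (h + x, 2 * (h + x))

new_state, output_seq = scan(toy_cell, 0, [1, 2, 3])
```

Here the state is a running sum and each output is twice the current state, so the shape matches `new_rnn_state, output_seq = scan(rnn_forward, rnn_state, input_seq)`.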

view this post on Zulip Mason Protter (May 28 2022 at 22:37):

Hm, maybe the onlast option can do that?

view this post on Zulip Mason Protter (May 29 2022 at 02:14):

@Takafumi Arakaki (tkf) any ideas?

view this post on Zulip Takafumi Arakaki (tkf) (May 29 2022 at 03:46):

You can do

julia> rf = ScanEmit(-1) do u, x
           @show u
           (x+1, u-1)
       end'(push!!) |> Completing
       acc = Transducers.start(rf, Int[])
       acc = Transducers.foldl_nocomplete(rf, acc, 1:3);
u = -1
u = -2
u = -3

julia> Transducers.unwrap(rf, acc)  # (u, ys)
(-4, [2, 3, 4])

although Transducers.foldl_nocomplete is not documented

Edit: s/append!!/push!!/

view this post on Zulip Brian Chen (May 29 2022 at 04:14):

That works! Now I need to figure out how to make it Zygote-friendly...
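One direction, sketched here with Base only, is to avoid mutation entirely, since Zygote does not support mutating arrays with `push!`. The name `scan_immutable` is hypothetical, and whether Zygote differentiates this efficiently has not been checked. The outputs are accumulated in a growing tuple, so this is only reasonable for short sequences.

```julia
# Hypothetical non-mutating scan: the fold state is (carry, outputs_so_far).
function scan_immutable(f, init, xs)
    fold_step(state, x) = begin
        carry, outs = state
        new_carry, y = f(carry, x)
        (new_carry, (outs..., y))   # build a new tuple instead of mutating
    end
    final_carry, all_outs = foldl(fold_step, xs; init=(init, ()))
    return final_carry, collect(all_outs)
end

result_immutable = scan_immutable((u, x) -> (u - 1, x + 1), -1, 1:3)
```

The tuple grows by one element per step, which costs quadratic copying and heavy specialization on long inputs.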

view this post on Zulip Takafumi Arakaki (tkf) (May 29 2022 at 04:55):

I'd guess that this is not very Zygote friendly, unfortunately... As this is a very generic concept, why not actually write it by hand and also define the chain rule (i.e., make it effectively a "builtin")?

(The foldl part can be optimized by https://github.com/JuliaFolds/FoldsChainRules.jl but we still have push!!)

view this post on Zulip Brian Chen (May 29 2022 at 04:57):

Was thinking of how to write one :) a function like push!! would make the rule itself trivial (well, trivial for Vectors at least), but nested differentiation makes my head hurt

view this post on Zulip Brian Chen (May 29 2022 at 04:57):

Do you think there's any chance of something like this being in Base?

view this post on Zulip Brian Chen (May 29 2022 at 05:01):

One advantage jax has is that they enforce inputs/outputs to be scalar numbers, arrays or structures of arrays. That allows for tricks like pre-allocating an output buffer for scanning over the first dimension of a multi-dimensional array.

view this post on Zulip Takafumi Arakaki (tkf) (May 29 2022 at 09:52):

Brian Chen said:

One advantage jax has is that they enforce inputs/outputs to be scalar numbers, arrays or structures of arrays.

Yes, I agree it is an advantage for the compiler/library authors. But it sounds like a flip side of a disadvantage (= limit the kind of program written easily) on the user side to me. In Julia, I think it'd be more natural to let the users express the constraints (output is an array, output shape is known, output element type, etc.) through an API function call. I think it'd be some extension of the "into" PR I opened in julia.
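As a rough illustration of passing such constraints explicitly, here is a sketch in which the caller states the output element type and per-step output length, so the output buffer can be pre-allocated while scanning over the first dimension. The function `scan_prealloc` and its signature are hypothetical and are not the API of the "into" PR or of any existing package.

```julia
# Hypothetical API: the caller declares the output eltype T and the output length.
# Assumes f(acc, row) returns (new_acc, y) with length(y) == outlen.
function scan_prealloc(f, init, xs::AbstractMatrix, ::Type{T}, outlen::Integer) where {T}
    n = size(xs, 1)
    ys = Matrix{T}(undef, n, outlen)   # output buffer allocated once up front
    acc = init
    for i in 1:n
        acc, y = f(acc, view(xs, i, :))
        ys[i, :] .= y                  # write this step's output into row i
    end
    return acc, ys
end

# Hypothetical step: add the row sum to the carry, emit the doubled row.
row_step(acc, row) = (acc + sum(row), 2 .* row)

total, doubled = scan_prealloc(row_step, 0, [1 2; 3 4; 5 6], Int, 2)
```

The in-place writes mean this version is also not differentiable by Zygote as written, so it only illustrates the allocation side of the trade-off.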

view this post on Zulip Brian Chen (May 29 2022 at 13:52):

I agree 100%, but the perfect is also kind of the enemy of the good here. Given very limited time as a library author, I'm more likely to leave the status quo of bad performance than to try rolling my own little subsystem for this.

view this post on Zulip Brian Chen (May 29 2022 at 14:52):

Also double-checked that I am indeed subscribed to that into PR :big_smile:


Last updated: Nov 06 2024 at 04:40 UTC