Copied from Slack and hoping for some more expert opinions here :)
Do we have an equivalent to https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html? accumulate is the closest stdlib function I know of, but it doesn't allow for carrying non-output intermediate states. https://juliafolds.github.io/Transducers.jl/dev/reference/manual/#Transducers.ScanEmit was the best example I could find in the broader ecosystem, but I'm not sure if there's a way to extract uₙ from it.
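One Base-only workaround, sketched under the assumption that materialising every intermediate state is acceptable: let accumulate carry a (state, output) tuple and split the result afterwards. The name emit_step and the dummy 0 in init are made up for illustration.

```julia
# Carry (state, output) tuples through Base.accumulate, then split them.
# The 0 in `init` is a dummy output that is never read.
emit_step((u, _), x) = (u - 1, x + 1)   # returns (new_state, output)

pairs_ = accumulate(emit_step, 1:3; init = (-1, 0))
u_final = first(last(pairs_))   # final carried state
ys = last.(pairs_)              # per-step outputs
```

This keeps every intermediate state alive in the result, which is exactly the overhead a dedicated scan would avoid.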
@Brian Chen yeah, the uₙ should be available, at least depending on what you're planning to do with them
julia> 1:3 |> ScanEmit(-1) do u, x
           @show u
           (x+1, u-1)
       end |> collect
u = -1
u = -2
u = -3
3-element Vector{Int64}:
 2
 3
 4
The jax readme was not very enlightening to me, maybe you could give a julian pseudoexample of what you want?
It's available but not returned, unfortunately. For comparison, JAX's scan would return (-4, [2, 3, 4]).
I could try a local var, but closure boxing makes that difficult. A Ref or other mutable struct is a no-go too, because of Zygote
Mason Protter said:
maybe you could give a julian pseudoexample of what you want?
function scan(f, init, xs)
    acc = init   # carried state
    ys = []      # per-step outputs (untyped for brevity)
    for x in xs
        acc, y = f(acc, x)   # f returns (new_state, output)
        push!(ys, y)
    end
    return acc, ys
end
This is the minimal version. JAX does a bunch of extra stuff with dim handling and SoA inputs, but all I want for Flux/Lux is new_rnn_state, output_seq = scan(rnn_forward, rnn_state, input_seq).
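For a concrete feel of that call shape, here is a minimal sketch with a toy cell. scan is repeated from the minimal version so the snippet runs on its own; toy_rnn and its 0.5 weight are made up for illustration and are not Flux/Lux API.

```julia
# Same loop as the minimal version, repeated so this runs standalone.
function scan(f, init, xs)
    acc = init
    ys = []
    for x in xs
        acc, y = f(acc, x)
        push!(ys, y)
    end
    return acc, ys
end

# Hypothetical toy cell: the new state doubles as the per-step output.
function toy_rnn(h, x)
    h_new = tanh.(0.5 .* h .+ x)
    return h_new, h_new
end

rnn_state = zeros(2)
input_seq = [fill(0.1 * t, 2) for t in 1:4]
new_rnn_state, output_seq = scan(toy_rnn, rnn_state, input_seq)

# The ScanEmit example's step, rewritten in (new_state, output) order:
u_final, ys = scan((u, x) -> (u - 1, x + 1), -1, 1:3)
```

With the integer step, u_final is -4 and ys holds 2, 3, 4, matching what JAX's scan would return for the same step.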
Hm, maybe the onlast option can do that?
@Takafumi Arakaki (tkf) any ideas?
You can do
julia> rf = ScanEmit(-1) do u, x
           @show u
           (x+1, u-1)
       end'(push!!) |> Completing
       acc = Transducers.start(rf, Int[])
       acc = Transducers.foldl_nocomplete(rf, acc, 1:3);
u = -1
u = -2
u = -3
julia> Transducers.unwrap(rf, acc) # (u, ys)
(-4, [2, 3, 4])
although Transducers.foldl_nocomplete is not documented
Edit: s/append!!/push!!/
That works! Now I need to figure out how to make it Zygote-friendly...
I'd guess that this is not very Zygote friendly, unfortunately... As this is a very generic concept, why not actually write it by hand and also define the chain rule (i.e., make it effectively a "builtin")?
(The foldl part can be optimized by https://github.com/JuliaFolds/FoldsChainRules.jl but we still have push!!)
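To make the "write it by hand" idea concrete, here is a rough, dependency-free sketch of a scan with a hand-written reverse pass. It is not a ChainRules rule: scan_with_pullback, toy_cell, and the convention that each step returns its own pullback are all assumptions for illustration. A real rule would be defined through ChainRulesCore and would obtain the per-step pullbacks from the AD system instead.

```julia
# Hypothetical convention: stepfn(acc, x) returns ((new_acc, y), back),
# where back(d_acc, d_y) returns (d_acc_prev, d_x).
function scan_with_pullback(stepfn, init, xs::AbstractVector)
    acc = init
    ys = Float64[]
    backs = Any[]
    for x in xs
        (acc, y), back = stepfn(acc, x)
        push!(ys, y)
        push!(backs, back)
    end
    # Reverse pass: walk the saved per-step pullbacks from last to first.
    function scan_pullback(d_acc, d_ys)
        d_xs = zeros(length(backs))
        for i in reverse(eachindex(backs))
            d_acc, d_xs[i] = backs[i](d_acc, d_ys[i])
        end
        return d_acc, d_xs   # gradients w.r.t. init and xs
    end
    return (acc, ys), scan_pullback
end

# Toy scalar cell with a hand-derived pullback (illustrative only).
function toy_cell(h, x)
    h_new = tanh(0.5 * h + x)
    y = 2 * h_new
    function back(d_h, d_y)
        g = (d_h + 2 * d_y) * (1 - h_new^2)   # gradient at the tanh input
        return 0.5 * g, g                      # (d_h_prev, d_x)
    end
    return (h_new, y), back
end

xs = [0.1, -0.3, 0.7]
(h_final, ys), pullback = scan_with_pullback(toy_cell, 0.2, xs)
# Gradient of loss = h_final + sum(ys) w.r.t. the initial state and inputs.
d_init, d_xs = pullback(1.0, ones(length(xs)))
```

The forward pass still uses push!, but only outside the differentiated code path, so the rule itself never has to differentiate through mutation. Nested differentiation is not addressed here.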
Was thinking of how to write one :) a function like push!! would make the rule itself trivial (well, trivial for Vectors at least), but nested differentiation makes my head hurt
Do you think there's any chance of something like this being in Base?
One advantage jax has is that they enforce inputs/outputs to be scalar numbers, arrays or structures of arrays. That allows for tricks like pre-allocating an output buffer for scanning over the first dimension of a multi-dimensional array.
Brian Chen said:
One advantage jax has is that they enforce inputs/outputs to be scalar numbers, arrays or structures of arrays.
Yes, I agree it is an advantage for the compiler/library authors. But to me it sounds like the flip side of a disadvantage on the user side (it limits the kinds of programs that are easy to write). In Julia, I think it'd be more natural to let users express the constraints (output is an array, output shape is known, output element type, etc.) through an API function call. I think it'd be some extension of the "into" PR I opened in julia.
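As a rough illustration of expressing those constraints through a function call, the sketch below lets the caller state the output element type, with the length taken from the input, so the buffer is allocated once. scan_into and its signature are hypothetical; this is not the API of the "into" PR.

```julia
# Hypothetical API: the caller declares the output element type `T`;
# the output length is taken from `xs`, so no push! is needed.
function scan_into(f, ::Type{T}, init, xs::AbstractVector) where {T}
    ys = Vector{T}(undef, length(xs))   # preallocated output buffer
    acc = init
    for (i, x) in enumerate(xs)
        acc, y = f(acc, x)              # f returns (new_state, output)
        ys[i] = y
    end
    return acc, ys
end

u_final, ys = scan_into((u, x) -> (u - 1, x + 1), Int, -1, 1:3)
```

Unlike JAX, nothing here forces the outputs to be arrays; the caller opts in to the constraint only when it holds.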
I agree 100%, but the perfect is also kind of the enemy of the good here. Given very limited time as a library author, I'm more likely to leave the status quo of bad performance than to try rolling my own little subsystem for this.
Also double-checked that I am indeed subscribed to that into PR :big_smile:
Last updated: Nov 06 2024 at 04:40 UTC