Copied from Slack and hoping for some more expert opinions here :)
Do we have an equivalent to https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html? `accumulate` is the closest thing in Base that I know of, but it doesn't let you carry intermediate state that isn't part of the output. https://juliafolds.github.io/Transducers.jl/dev/reference/manual/#Transducers.ScanEmit was the best example I could find in the broader ecosystem, but I'm not sure if there's a way to extract the final state uₙ from it.
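For context, the usual workaround with Base's `accumulate` is to thread a (state, output) pair through the accumulation and split it afterwards. This is a minimal sketch that assumes a step function returning `(new_state, output)`; `scan_via_accumulate` and `countdown_step` are hypothetical names, not part of Base.

```julia
# Emulate scan with Base.accumulate by carrying (state, output) pairs and
# splitting them apart at the end.
function scan_via_accumulate(f, init, xs)
    isempty(xs) && return init, Any[]
    # Only the state half of the previous pair is fed back into `f`;
    # `nothing` stands in for the (non-existent) output before the first step.
    pairs = accumulate((prev, x) -> f(first(prev), x), xs; init = (init, nothing))
    return first(last(pairs)), last.(pairs)
end

countdown_step(u, x) = (u - 1, x + 1)   # hypothetical step: (new state, output)

scan_via_accumulate(countdown_step, -1, 1:3)   # (-4, [2, 3, 4])
```

One downside is that every intermediate state is stored next to its output in `pairs`, whereas a dedicated scan only needs to keep the final one.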
@Brian Chen yeah, they should be available, at least depending on what you're planning to do with them:
```julia
julia> using Transducers

julia> 1:3 |> ScanEmit(-1) do u, x
           @show u
           (x+1, u-1)
       end |> collect
u = -1
u = -2
u = -3
3-element Vector{Int64}:
 2
 3
 4
```
The jax readme was not very enlightening to me, maybe you could give a julian pseudoexample of what you want?
It's available but not returned, unfortunately. For comparison, JAX's scan would return (-4, [2, 3, 4]): the final state together with the outputs.
I could try a local variable, but closure boxing makes that difficult. A Ref or other mutable struct is a no-go too, because Zygote doesn't handle that kind of mutation well.
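For reference, the local-variable workaround could look roughly like this sketch, which reuses `ScanEmit` exactly as in the snippet above; `scan_emit_with_carry` is a hypothetical helper name. Because `carry` is captured by the `do` block and reassigned inside it, Julia stores it in a `Core.Box`, which defeats type inference: that is the boxing problem mentioned here.

```julia
using Transducers

# `g(u, x)` follows ScanEmit's convention and returns (output, next_state).
function scan_emit_with_carry(g, init, xs)
    carry = init
    ys = xs |> ScanEmit(init) do u, x
        y, next_u = g(u, x)
        carry = next_u   # reassigning a captured variable => `carry` becomes a Core.Box
        (y, next_u)
    end |> collect
    return carry, ys
end

scan_emit_with_carry((u, x) -> (x + 1, u - 1), -1, 1:3)   # (-4, [2, 3, 4])
```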
Mason Protter said:
maybe you could give a julian pseudoexample of what you want?
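```julia
# f(acc, x) must return (new_acc, y); returns the final acc and all the ys.
function scan(f, init, xs)
    acc = init
    ys = []                        # Vector{Any}: the output type isn't known up front
    for x in xs
        acc, y = f(acc, x)
        push!(ys, y)
    end
    return acc, map(identity, ys)  # narrow to a concrete element type
end
```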
This is the minimal version. JAX does a bunch of extra stuff with dimension handling and struct-of-arrays (SoA) inputs, but all I want for Flux/Lux is `new_rnn_state, output_seq = scan(rnn_forward, rnn_state, input_seq)`.
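To make that use case concrete, here is a toy sketch in plain Julia, without Flux or Lux. `rnn_forward` is a hypothetical Elman-style cell with arbitrary fixed weights `W` and `U`: the hidden state is the carry, and each step's new hidden state is also emitted as that step's output.

```julia
# Minimal scan: f(carry, x) -> (new_carry, y)
function scan(f, init, xs)
    carry = init
    ys = []
    for x in xs
        carry, y = f(carry, x)
        push!(ys, y)
    end
    return carry, map(identity, ys)
end

const W = [0.5 -0.1; 0.2 0.3]   # hidden-to-hidden weights (arbitrary values)
const U = [1.0 0.0; 0.0 1.0]    # input-to-hidden weights (arbitrary values)

# Hypothetical RNN cell: returns (new hidden state, output)
function rnn_forward(h, x)
    h_new = tanh.(W * h .+ U * x)
    return h_new, h_new
end

rnn_state = zeros(2)
input_seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
new_rnn_state, output_seq = scan(rnn_forward, rnn_state, input_seq)
```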
Hm, maybe the onlast option can do that?
@Takafumi Arakaki (tkf) any ideas?
You can do:
```julia
julia> using Transducers, BangBang

julia> rf = ScanEmit(-1) do u, x
           @show u
           (x+1, u-1)
       end'(push!!) |> Completing;

julia> acc = Transducers.start(rf, Int[]);

julia> acc = Transducers.foldl_nocomplete(rf, acc, 1:3);
u = -1
u = -2
u = -3

julia> Transducers.unwrap(rf, acc)  # (u, ys)
(-4, [2, 3, 4])
```
although Transducers.foldl_nocomplete is not documented
Edit: s/append!!/push!!/
That works! Now I need to figure out how to make it Zygote-friendly...
I'd guess that this is not very Zygote friendly, unfortunately... As this is a very generic concept, why not actually write it by hand and also define the chain rule (i.e., make it effectively a "builtin")?
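A hand-written `scan` with its own reverse rule could look roughly like the sketch below. It assumes ChainRulesCore's `rrule`/`rrule_via_ad` interface with Zygote as the AD backend; `scan`, `toy_step` and `loss` are illustrative names, not an existing package API. The rule keeps one pullback per step and sweeps them in reverse, threading the carry's cotangent backwards. That costs O(n) memory, and it was not designed with nested (higher-order) differentiation in mind.

```julia
using ChainRulesCore
using Zygote  # assumed AD backend; provides rrule_via_ad for arbitrary f

# Forward-only scan: f(carry, x) must return (new_carry, y).
function scan(f, init, xs)
    carry = init
    ys = []
    for x in xs
        carry, y = f(carry, x)
        push!(ys, y)
    end
    return carry, map(identity, ys)
end

# Reverse rule: record every step's pullback, then walk the steps backwards,
# feeding each one the cotangent of the carry it produced.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(scan), f, init, xs)
    carry = init
    ys = []
    pullbacks = []
    for x in xs
        (carry, y), pb = rrule_via_ad(config, f, carry, x)
        push!(ys, y)
        push!(pullbacks, pb)
    end
    function scan_pullback(Δ)
        Δ isa ChainRulesCore.AbstractZero &&
            return (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent())
        Δcarry, Δys = unthunk(Δ)
        Δcarry, Δys = unthunk(Δcarry), unthunk(Δys)
        Δf = ZeroTangent()                   # accumulated over all steps
        Δxs = Vector{Any}(undef, length(pullbacks))
        for i in reverse(eachindex(pullbacks))
            Δy = Δys isa ChainRulesCore.AbstractZero ? Δys : Δys[i]
            # step i maps (carry_{i-1}, x_i) -> (carry_i, y_i)
            Δfi, Δcarry, Δxs[i] = pullbacks[i]((Δcarry, Δy))
            Δf += Δfi
        end
        return NoTangent(), Δf, Δcarry, map(identity, Δxs)
    end
    return (carry, map(identity, ys)), scan_pullback
end

# Usage with a toy step: new carry = carry * x, output = carry + x
toy_step(c, x) = (c * x, c + x)
function loss(init, xs)
    c, ys = scan(toy_step, init, xs)
    return c + sum(ys)
end
Zygote.gradient(loss, 2.0, [3.0, 4.0])  # by hand: ∂init = 16, ∂xs = [11, 7]
```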
(The foldl part can be handled by https://github.com/JuliaFolds/FoldsChainRules.jl, but push!! would still be a problem.)
I was thinking about how to write one :) A function like push!! would make the rule itself trivial (well, trivial for Vectors at least), but nested differentiation makes my head hurt.
Do you think there's any chance of something like this ending up in Base?
One advantage jax has is that they enforce inputs/outputs to be scalar numbers, arrays or structures of arrays. That allows for tricks like pre-allocating an output buffer for scanning over the first dimension of a multi-dimensional array.
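As a rough illustration of why fixed shapes help: if every per-step output has a known length, a scan can write into a preallocated matrix instead of growing a vector. In this hypothetical `scan_into!` the scan runs over columns rather than the first dimension, because columns are contiguous in Julia's column-major layout; that choice is an assumption of the sketch, not how JAX does it.

```julia
# Scan over the columns of `xs`, writing each step's output into the
# preallocated matrix `ys`. f(carry, x) must return (new_carry, y), where
# y has length size(ys, 1).
function scan_into!(f, ys::AbstractMatrix, init, xs::AbstractMatrix)
    size(ys, 2) == size(xs, 2) ||
        throw(DimensionMismatch("ys and xs must have the same number of columns"))
    carry = init
    for j in axes(xs, 2)
        carry, y = f(carry, view(xs, :, j))
        ys[:, j] .= y   # write step j's output into column j in place
    end
    return carry, ys
end

# Usage: running sum over the columns
running_sum(c, x) = (c .+ x, c .+ x)
xs = [1.0 2.0 3.0; 4.0 5.0 6.0]
final, ys = scan_into!(running_sum, zeros(2, 3), zeros(2), xs)
# final == [6.0, 15.0]; ys == [1.0 3.0 6.0; 4.0 9.0 15.0]
```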
Brian Chen said:
One advantage jax has is that they enforce inputs/outputs to be scalar numbers, arrays or structures of arrays.
Yes, I agree it is an advantage for the compiler/library authors. But to me it sounds like the flip side of a disadvantage on the user side: it limits the kinds of programs that are easy to write. In Julia, I think it'd be more natural to let users express the constraints (the output is an array, its shape is known, its element type, etc.) through an API function call. I think that would be some extension of the "into" PR I opened in the Julia repo.
I agree 100%, but the perfect is also kind of the enemy of the good here. Given very limited time as a library author, I'm more likely to leave the status quo of bad performance than to try rolling my own little subsystem for this.
Also double-checked that I am indeed subscribed to that into PR :big_smile:
Last updated: Oct 26 2025 at 04:40 UTC