Is there something like `all(pred, collection)` but that requires one `false` value? I'm thinking that `all` is like chaining `and`, `any` is chaining `or`, and I want something that's (sort of) chaining `xor`. A non-short-circuiting version would just be something like
```julia
julia> allbutone(f, collection) = length(collection) - sum(f, collection) == 1
allbutone (generic function with 1 method)

julia> allbutone(>(1), 1:10)
true

julia> allbutone(>(1), 0:10)
false

julia> allbutone(>(1), 2:10)
false
```
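The same non-short-circuiting check can also be written with Base's `count`. Unlike `length`, it doesn't need to know the collection's size up front, so it also works on iterators like `Iterators.filter`, where the `length`-based version above would fail because `length` isn't defined for them. A minimal sketch; the name `allbutone_count` is just for illustration:

```julia
# Non-short-circuiting "all but one": count the elements where `f` is false.
# `!f` negates a predicate, and `count(p, itr)` counts the elements where `p` holds.
allbutone_count(f, itr) = count(!f, itr) == 1

allbutone_count(>(1), 1:10)                            # true: only x = 1 fails
allbutone_count(>(1), 0:10)                            # false: 0 and 1 fail
allbutone_count(>(1), Iterators.filter(isodd, 1:10))   # true: works without `length`
```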
I suppose the `xor` thing doesn't make sense, since another interpretation would be `onlyone` or something...
I'm relatively confident you'll need to just write the short-circuiting version yourself; this seems like a really special-purpose function. The short-circuiting version is only a few lines of code anyway.
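For reference, a direct short-circuiting version that counts `false`s could look like this (a sketch; `allbutone_sc` is a made-up name). It returns as soon as a second `false` shows up, so it even terminates on an infinite iterator that fails twice early on:

```julia
# Short-circuiting "all but one": stop as soon as a second false turns up.
function allbutone_sc(f, itr)
    nfalse = 0
    for x in itr
        if !f(x)
            nfalse += 1
            nfalse > 1 && return false  # two falses: no need to look further
        end
    end
    return nfalse == 1
end

allbutone_sc(>(1), 1:10)                     # true
allbutone_sc(>(1), Iterators.countfrom(0))   # false, after looking only at 0 and 1
```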
The inverse (finding exactly 1 `true`) is slightly easier to write, and you can just pass `!f` anyway:
```julia
function exactly_n(f, itr, n=1)
    count = 0
    for x in itr
        count += f(x)
        count > n && return false
    end
    return count == n
end
```
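To make the `!f` trick concrete: "all but one" is the same as "exactly one `false`", i.e. exactly one `true` for the negated predicate. A small sketch (the wrapper name `allbutone_via_n` is just illustrative; `exactly_n` is repeated so the snippet runs on its own):

```julia
function exactly_n(f, itr, n=1)
    count = 0
    for x in itr
        count += f(x)              # Bool is added as 0 or 1
        count > n && return false  # already too many: stop early
    end
    return count == n
end

# "All but one" == exactly one true for the negated predicate `!f`.
allbutone_via_n(f, itr) = exactly_n(!f, itr)

allbutone_via_n(>(1), 1:10)   # true
exactly_n(iseven, 1:5, 2)     # true: 2 and 4
```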
Note that depending on your input data, the version without short-circuiting might be much faster.
For best performance you probably would want to check `count > n` after several iterations rather than every iteration (assuming `f` is relatively inexpensive).
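One way to do the "check every few iterations" idea for arrays, as a sketch only: the name `exactly_n_blocked` and the `block` keyword are hypothetical, and 64 is an arbitrary default rather than a tuned value. The inner loop over each block has no early exit, and `count > n` is tested once per block:

```julia
# Hypothetical blocked variant of `exactly_n` for vectors.
function exactly_n_blocked(f, v::AbstractVector, n=1; block::Int=64)
    count = 0
    i, hi = firstindex(v), lastindex(v)
    while i <= hi
        stop = min(i + block - 1, hi)
        @inbounds for j in i:stop   # no branch on `count` inside the block
            count += f(v[j])
        end
        count > n && return false   # checked once per block, not per element
        i = stop + 1
    end
    return count == n
end

exactly_n_blocked(<=(1), 1:10)   # true: only 1 satisfies x <= 1
```

The trade-off is that up to `block - 1` extra elements may be evaluated after the answer is already known.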
(I wonder if `Base` should have some sort of generic short-circuiting reduction implementation :thinking: )
Transducers.jl supports early termination through `reduced`, which is basically just short-circuiting but more general.
If you want something that's not too 'transducery' though, you could write
```julia
julia> function short_circuit_reduce(f, op, itr; init, flag)
           state = init
           for x in itr
               state = op(state, f(x))
               state == flag && return state
           end
           state
       end;
```
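Used with plain `&` or `|`, `short_circuit_reduce` is just a short-circuiting `all` or `any`, which is a handy sanity check before plugging in anything fancier (definition repeated so the snippet runs on its own):

```julia
function short_circuit_reduce(f, op, itr; init, flag)
    state = init
    for x in itr
        state = op(state, f(x))
        state == flag && return state   # stop once the result can't change
    end
    state
end

# With `&` and flag = false this behaves like a short-circuiting `all`...
short_circuit_reduce(iseven, &, [2, 4, 5, 6]; init = true, flag = false)   # false, stops at 5
# ...and with `|` and flag = true like a short-circuiting `any`.
short_circuit_reduce(iseven, |, [1, 3, 4, 5]; init = false, flag = true)   # true, stops at 4
```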
Then, we just make a version of `&` which will "forgive" us if we give it a `false` input `n` times:
```julia
julia> mutable struct ForgivingAnd
           n::Int
       end

julia> (fga::ForgivingAnd)(p, q) = p & q || (fga.n -= 1) >= 0
```
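To see what `ForgivingAnd` does on its own, here it is called directly (definitions repeated so this runs standalone; `fa` is just a throwaway name). Because `&` binds more tightly than `||`, a `false` from `p & q` spends one unit of the budget `n`, and the call still returns `true` as long as `n` hasn't gone negative:

```julia
mutable struct ForgivingAnd
    n::Int   # how many more `false`s may be forgiven
end

(fga::ForgivingAnd)(p, q) = p & q || (fga.n -= 1) >= 0

fa = ForgivingAnd(1)
fa(true, true)    # true  (n stays 1)
fa(true, false)   # true  (forgiven, n is now 0)
fa(true, false)   # false (n is now -1)
```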
Now, `allbutone` just needs to stick a `ForgivingAnd(1)` into `short_circuit_reduce` and check that it forgave exactly 1 `false`:
```julia
julia> function allbutone(f, itr)
           (&) = ForgivingAnd(1)
           short_circuit_reduce(f, &, itr; init=true, flag = false)
           (&).n == 0
       end
```
Then, we can check if it's working correctly:
```julia
julia> allbutone(>(1), 1:10)
true

julia> allbutone(>(1), 0:10)
false

julia> allbutone(>(1), 2:10)
false
```
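A quick way to confirm it really short-circuits is to wrap the predicate in a call counter. This sketch repeats the definitions so it runs on its own; the only change is that the local `ForgivingAnd` is named `fand` instead of shadowing `&`, and `CALLS`/`counted_gt1` are hypothetical helpers for illustration:

```julia
function short_circuit_reduce(f, op, itr; init, flag)
    state = init
    for x in itr
        state = op(state, f(x))
        state == flag && return state
    end
    state
end

mutable struct ForgivingAnd
    n::Int
end
(fga::ForgivingAnd)(p, q) = p & q || (fga.n -= 1) >= 0

function allbutone(f, itr)
    fand = ForgivingAnd(1)   # same as shadowing `&`, just a plain name
    short_circuit_reduce(f, fand, itr; init=true, flag=false)
    fand.n == 0
end

# Predicate that records how many elements were actually inspected.
const CALLS = Ref(0)
counted_gt1(x) = (CALLS[] += 1; x > 1)

allbutone(counted_gt1, [0, 0, 5, 5, 5, 5, 5, 5])   # false
CALLS[]                                            # 2
```

Only the first two elements are inspected: the second `0` exhausts the budget, `state` becomes `false`, and the loop returns.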
If you only want to check whether to short-circuit once every `n` iterations, you could do something like
```julia
julia> function allbutone(f, itr; chunk_size=64)
           (&) = ForgivingAnd(1)
           short_circuit_reduce(f, &, itr; init=true, flag = false, chunk_size=chunk_size)
           (&).n == 0
       end

julia> function short_circuit_reduce(f, op, itr; init, flag, chunk_size::Int = 64)
           state = init
           for chunk in Iterators.partition(itr, chunk_size)
               for x in chunk
                   state = op(state, f(x))
               end
               state == flag && return state
           end
           state
       end
```
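The price of the chunked version is that it can only stop at a chunk boundary. Counting predicate calls on an input whose second element is already the second `false` shows this (definitions repeated so it runs standalone, with a local `fand` instead of a shadowed `&`; `NCALLS` and `counted_gt1` are hypothetical helpers):

```julia
mutable struct ForgivingAnd
    n::Int
end
(fga::ForgivingAnd)(p, q) = p & q || (fga.n -= 1) >= 0

# Chunked version: the flag is only tested after each chunk.
function short_circuit_reduce(f, op, itr; init, flag, chunk_size::Int = 64)
    state = init
    for chunk in Iterators.partition(itr, chunk_size)
        for x in chunk
            state = op(state, f(x))
        end
        state == flag && return state
    end
    state
end

function allbutone(f, itr; chunk_size=64)
    fand = ForgivingAnd(1)
    short_circuit_reduce(f, fand, itr; init=true, flag=false, chunk_size=chunk_size)
    fand.n == 0
end

const NCALLS = Ref(0)
counted_gt1(x) = (NCALLS[] += 1; x > 1)

allbutone(counted_gt1, [0, 0, 5, 5, 5, 5, 5, 5]; chunk_size=4)   # false
NCALLS[]                                                         # 4: the whole first chunk
```

After `state` turns `false` mid-chunk, the remaining calls keep decrementing `n`, but it only gets more negative, so `fand.n == 0` still correctly reports `false`.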
Then we can look at the performance implications of that for functions with tight loops, like the one @Kevin Bonham was interested in, in the case where it never actually short-circuits:
```julia
julia> using BenchmarkTools

julia> for N in (10, 100, 1000, 10_000)
           itr = rand(1:1000, N)
           @show N
           for chunk_size in 2 .^ (1:2:10)
               print(" "); @show chunk_size
               print(" "); @btime allbutone(>(0), $itr; chunk_size=$chunk_size)
           end
           println()
       end
N = 10
 chunk_size = 2
   30.130 ns (1 allocation: 16 bytes)
 chunk_size = 8
   22.942 ns (1 allocation: 16 bytes)
 chunk_size = 32
   20.972 ns (1 allocation: 16 bytes)
 chunk_size = 128
   20.993 ns (1 allocation: 16 bytes)
 chunk_size = 512
   20.963 ns (1 allocation: 16 bytes)

N = 100
 chunk_size = 2
   212.223 ns (1 allocation: 16 bytes)
 chunk_size = 8
   132.495 ns (1 allocation: 16 bytes)
 chunk_size = 32
   135.200 ns (1 allocation: 16 bytes)
 chunk_size = 128
   109.536 ns (1 allocation: 16 bytes)
 chunk_size = 512
   109.535 ns (1 allocation: 16 bytes)

N = 1000
 chunk_size = 2
   1.976 μs (1 allocation: 16 bytes)
 chunk_size = 8
   1.168 μs (1 allocation: 16 bytes)
 chunk_size = 32
   1.263 μs (1 allocation: 16 bytes)
 chunk_size = 128
   1.004 μs (1 allocation: 16 bytes)
 chunk_size = 512
   943.423 ns (1 allocation: 16 bytes)

N = 10000
 chunk_size = 2
   19.669 μs (1 allocation: 16 bytes)
 chunk_size = 8
   11.559 μs (1 allocation: 16 bytes)
 chunk_size = 32
   12.510 μs (1 allocation: 16 bytes)
 chunk_size = 128
   9.859 μs (1 allocation: 16 bytes)
 chunk_size = 512
   9.330 μs (1 allocation: 16 bytes)
```
So there are real gains to be had here, likely because dropping the per-element `state == flag` check from the inner loop lets SIMD kick in.
For a less generic version, you could make something screaming fast with LoopVectorization.jl of course.
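A Base-only sketch of the less generic idea, using `@simd` rather than LoopVectorization.jl: hardcoding the predicate as `x > t` (the name `allbutone_gt` is hypothetical) gives a branch-free counting loop, and `@simd` allows, but doesn't force, the compiler to vectorize it. This version never short-circuits:

```julia
# Hypothetical specialized version: does every element of `v` except exactly one satisfy x > t?
function allbutone_gt(v::AbstractVector{<:Real}, t::Real)
    nfalse = 0
    @inbounds @simd for i in eachindex(v)
        nfalse += !(v[i] > t)   # Bool is added as 0 or 1, no branch
    end
    return nfalse == 1
end

allbutone_gt(collect(1:10), 1)   # true
```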
Thanks for the thorough explanations! It actually hadn't occurred to me that writing the naive short-circuit version might end up slower.
@Adam non-jedi Beckmeyer my first thought was basically your `exactly_n`, but counting falses instead of trues :wink:.
@Mason Protter I think I'm going to have to sit down and study your implementations with a cup of coffee :big_smile:. I understand how to use iterators, but writing them is still not intuitive.
Last updated: Nov 22 2024 at 04:41 UTC