Can anyone help me with a batch scheduling problem?
Description:
I have a function FOO that I want to run on multiple threads at once. While FOO is running, it will occasionally call another function BAR with an argument unique to each thread that FOO is running on. I want BAR to batch process whatever FOO passes to it (say 32 at a time), and return the values to each unique thread that is running FOO. Is there a way to do this using the Task Library or Channel Library?
If the description of foo and bar is too abstract, I essentially have threads running that each want to classify images, but I want to batch the images before sending them to the GPU to be classified.
Does FOO have to wait for BAR? If so, what's different from a simple blocking call after accumulating 32 elements?
But, since you use GPU, maybe you want to pipeline several calls of BAR from each FOO?
Yes, FOO needs to wait for BAR before it can continue. I am essentially just trying to figure out how to batch the calls while ensuring that each output from BAR goes back to the exact FOO that called it. For example, I thought one way to do this might be to create tuples of the thread_id and the image, classify 32 at a time, and then let the worker threads pick up the results for the images they put in the channel for classification. But I was interested to know what the best mechanism for this type of idea is.
So, are the 32 items coming from different tasks executing FOO? i.e., you want to combine 32 calls from different tasks to issue one GPU kernel call?
...and want to make sure the result of the call goes back to the correct FOO?
Yes. Essentially, let's say:
Thread 1 calls G with arg 1
Thread 2 calls G with arg 2
Etc etc
Thread n calls G with arg n
I want G (the GPU) to operate on arguments 1, 2, ..., n in batches of 32 while making sure the result of G on arg 1 returns to thread 1, the result of G on arg 2 returns to thread 2, etc.
OK, I now understand the question. You essentially need a promise (future). The easiest way to do this in Julia is to use a single-element "one-shot" Channel (I think Go programmers use it all the time).
function batch_processor(f, request::Channel)
    batchsize = 32
    # It's better to use two arrays but for now this is an array of 2-tuples:
    batch = sizehint!(Vector{eltype(request)}(undef, 0), batchsize)
    nitems = 0
    for r in request
        push!(batch, r)
        nitems += 1
        if nitems == batchsize
            xs = first.(batch)
            ys = f(xs)
            promises = last.(batch)
            foreach(put!, promises, ys)
            empty!(batch)
            nitems = 0
        end
    end
end

function batched_call(request::Channel, x)
    promise = Channel(1)
    # promise = Channel{TypeOfY}(1)  # if you know the result type of f
    put!(request, (x, promise))
    return take!(promise)
end

function FOO(request::Channel)
    # ...
    y = batched_call(request, x)
    # ...
end

request = Channel()
# request = Channel(32)  # maybe better if it's buffered; tuning needed
@sync try
    @async try
        batch_processor(BAR, request)
    finally
        close(request)
    end
    for _ in 1:ntasks
        Threads.@spawn try
            FOO(request)
        finally
            close(request)
        end
    end
finally
    close(request)
end
I haven't done any extensive benchmarks, but here is a more direct implementation of promise than Channel(1)
https://github.com/JuliaFolds/FoldsThreads.jl/blob/08329298b73c054e1552f7f4a358e2cdcaf89027/src/utils.jl#L42
but, as you can see in the above code, the tedious part is more about how to use promises
(btw, the above code is totally untested. there can be typos etc.)
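To isolate the promise part from the batching part, here is a minimal sketch of a one-shot promise built on a capacity-1 Channel. The helper names `make_promise`, `fulfil!`, and `await_promise` are made up for illustration and are not part of Base.

```julia
# A one-shot promise built from a buffered Channel of capacity 1.
# `make_promise`, `fulfil!`, and `await_promise` are hypothetical helper names.
make_promise(::Type{T} = Any) where {T} = Channel{T}(1)

# put! on an empty capacity-1 channel does not block, so the producer moves on.
fulfil!(promise::Channel, value) = put!(promise, value)

# take! suspends the calling task until a value has been put.
await_promise(promise::Channel) = take!(promise)

promise = make_promise(Int)
consumer = Threads.@spawn await_promise(promise)  # waits until fulfilled
fulfil!(promise, 42)
result = fetch(consumer)
```

The buffer size of 1 matters here: with an unbuffered channel the fulfilling side would block until the waiting task actually takes the value.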
An interesting point is that it would be a useful pattern to use ntasks larger than nthreads() or the number of CPUs. The scheduler will suspend the task waiting on take!(promise) and switch to a new task, executing on a shared worker pool on OS threads.
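A small sketch of that oversubscription point, assuming nothing beyond Base: many more tasks than threads each wait on their own promise, and they all complete once the promises are fulfilled. The factor of 8 is an arbitrary choice for the demonstration.

```julia
# Spawn far more tasks than OS threads; each one waits on its own promise.
ntasks = 8 * Threads.nthreads()
promises = [Channel{Int}(1) for _ in 1:ntasks]

tasks = map(promises) do p
    # take! suspends this task only; the worker thread is free to run others.
    Threads.@spawn take!(p) + 1
end

# Fulfil the promises afterwards; put! on a capacity-1 channel does not block.
for (i, p) in enumerate(promises)
    put!(p, i)
end

results = fetch.(tasks)
```

Because a task blocked on `take!` yields to the scheduler, this finishes even when Julia is started with a single thread.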
(I just noticed that batch_processor might need try-finally around the for loop to close all pending promises, to ensure proper shutdown.)
Thank you so much! This was incredibly helpful. I will take a look into Go style programming.
You are welcome :) Yeah, I think Go style is a good inspiration for doing this kind of thing.
I was just tweaking my above post since I just keep noticing some gotchas :)
function batch_processor(f, request::Channel{Tuple{X,Promise}}) where {X,Promise}
    batchsize = 32
    xs = sizehint!(Vector{X}(undef, 0), batchsize)
    promises = sizehint!(Vector{Promise}(undef, 0), batchsize)
    nitems = 0
    try
        for (x, p) in request
            push!(xs, x)
            push!(promises, p)
            nitems += 1
            if nitems == batchsize
                foreach(put!, promises, f(xs))
                empty!(xs)
                empty!(promises)
                nitems = 0
            end
        end
        # process any leftovers once the request channel is closed
        if !isempty(xs)
            foreach(put!, promises, f(xs))
        end
    catch
        # unblock tasks waiting on pending promises
        foreach(close, promises)
        rethrow()
    end
end

function batched_call(request::Channel{Tuple{X,Channel{Y}}}, x::X) where {X,Y}
    promise = Channel{Y}(1)
    put!(request, (x, promise))
    return take!(promise)
end

function FOO(request::Channel)
    # ...
    y = batched_call(request, x)
    # ...
end

X = Y = Any  # or more concrete, if known
request = Channel{Tuple{X,Channel{Y}}}()
# request = Channel{Tuple{X,Channel{Y}}}(32)  # maybe better if it's buffered; tuning needed
@sync try
    @async try
        batch_processor(BAR, request)
    finally
        close(request)
    end
    for _ in 1:ntasks
        Threads.@spawn try
            FOO(request)
        catch
            close(request)
            rethrow()
        end
    end
finally
    close(request)
end
It seems really close to working but is throwing an error message as follows:
TaskFailedException
nested task error: InvalidStateException("Channel is closed.", :closed)
Stacktrace:
[1] check_channel_state
@ ./channels.jl:170 [inlined]
[2] put!
@ ./channels.jl:314 [inlined]
[3] batched_call(request::Channel{Tuple{Any, Channel{Any}}}, x::Int64)
@ Main ./In[22]:35
[4] FOO(request::Channel{Tuple{Any, Channel{Any}}})
@ Main ./In[22]:41
[5] (::var"#38#40")()
@ Main ./threadingconstructs.jl:169
...and 99 more exceptions.
Stacktrace:
[1] sync_end(c::Channel{Any})
@ Base ./task.jl:369
[2] top-level scope
@ task.jl:388
[3] eval
@ ./boot.jl:360 [inlined]
[4] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1094
This is the code I used
function batch_processor(f, request::Channel{Tuple{X,Promise}}) where {X,Promise}
    batchsize = 32
    # list of images to classify
    # support 32 elements, but if less, don't give empty elements back
    xs = sizehint!(Vector{X}(undef, 0), batchsize)
    promises = sizehint!(Vector{Promise}(undef, 0), batchsize)
    # classify the images in batches
    nitems = 0
    try
        for (x, p) in request
            push!(xs, x)
            push!(promises, p)
            nitems += 1
            if nitems == batchsize
                foreach(put!, promises, f(xs))
                empty!(xs)
                empty!(promises)
                nitems = 0
            end
        end
        # classify any leftovers
        if !isempty(xs)
            foreach(put!, promises, f(xs))
        end
    catch
        foreach(close, promises)
        rethrow()
    end
end

function batched_call(request::Channel{Tuple{X,Channel{Y}}}, x::X) where {X,Y}
    promise = Channel{Y}(1)
    put!(request, (x, promise))
    return take!(promise)
end

function FOO(request::Channel)
    x = 5
    y = batched_call(request, x)
    return x == y
end

function BAR(x)
    return x
end

X = Y = Any
request = Channel{Tuple{X,Channel{Y}}}()  # (32) - maybe better if it's buffered; tuning needed
ntasks = 100
@sync try
    @async try
        batch_processor(BAR, request)
    finally
        close(request)
    end
    for _ in 1:ntasks
        Threads.@spawn try
            FOO(request)
        catch
            close(request)
            rethrow()
        end
    end
finally
    close(request)
end
can you put three backticks before and after the code?
Yes! Sorry, my bad! I was trying to figure out how to format it.
It's a bit hard to diagnose since InvalidStateException("Channel is closed.", :closed) is not the cause of the initial error. It's not elegant, but I often put
    catch err
        @error "Error at XXX" exception = (err, catch_backtrace())
before each finally block in each @spawn and @async, to figure out the initial cause.
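As a rough sketch of that logging pattern in isolation: the helper `logged` and the function `failing_work` below are hypothetical names, and the `error("root cause")` call just simulates a failure whose cleanup closes a channel.

```julia
# `logged` is a hypothetical helper: run `f`, log the first error with a label
# and backtrace, then rethrow so whoever waits on the task still sees the failure.
function logged(f, label::AbstractString)
    try
        return f()
    catch err
        @error "Error at $label" exception = (err, catch_backtrace())
        rethrow()
    end
end

# Simulated failure: the cleanup closes the channel, which is what produces
# the follow-on "Channel is closed." errors in other tasks.
function failing_work(ch::Channel)
    try
        error("root cause")
    finally
        close(ch)
    end
end

ch = Channel{Int}(1)
failing = @async logged(() -> failing_work(ch), "producer")

# Waiting on a failed task throws; capture the exception instead of crashing.
root = try
    wait(failing)
    nothing
catch err
    err
end
```

The log line is emitted where the error first occurs, so it shows up before any secondary exceptions caused by the closed channel.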
This is the error message now:
┌ Error: Error in batch processor calls to BAR
│ exception = (InvalidStateException("Channel is closed.", :closed), Union{Ptr{Nothing}, Base.InterpreterIP}[Ptr{Nothing} @0x000000016adffc06, Ptr{Nothing} @0x000000016adffce7, Ptr{Nothing} @0x000000010a9bf015, Ptr{Nothing} @0x000000016adff4e8, Ptr{Nothing} @0x000000016adff7ac, Ptr{Nothing} @0x000000010a9bf015, Ptr{Nothing} @0x000000010a9da25d])
└ @ Main In[25]:66
Code is below:
function batch_processor(f, request::Channel{Tuple{X,Promise}}) where {X,Promise}
    batchsize = 32
    # list of images to classify
    # support 32 elements, but if less, don't give empty elements back
    xs = sizehint!(Vector{X}(undef, 0), batchsize)
    promises = sizehint!(Vector{Promise}(undef, 0), batchsize)
    # classify the images in batches
    nitems = 0
    try
        for (x, p) in request
            push!(xs, x)
            push!(promises, p)
            nitems += 1
            if nitems == batchsize
                foreach(put!, promises, f(xs))
                empty!(xs)
                empty!(promises)
                nitems = 0
            end
        end
        # classify any leftovers
        if !isempty(xs)
            foreach(put!, promises, f(xs))
        end
    catch
        foreach(close, promises)
        rethrow()
    end
end

function batched_call(request::Channel{Tuple{X,Channel{Y}}}, x::X) where {X,Y}
    promise = Channel{Y}(1)
    put!(request, (x, promise))
    return take!(promise)
end

function FOO(request::Channel)
    x = 5
    y = batched_call(request, x)
    return x == y
end

function BAR(x)
    return x
end

X = Y = Any
request = Channel{Tuple{X,Channel{Y}}}()  # (32) - maybe better if it's buffered; tuning needed
ntasks = 100
@sync try
    @async try
        batch_processor(BAR, request)
    catch err
        @error "Error in batch processor calls to BAR" exception = (err, catch_backtrace())
    finally
        close(request)
    end
    for _ in 1:ntasks
        Threads.@spawn try
            FOO(request)
        catch err
            @error "Error in batch processor calls to BAR" exception = (err, catch_backtrace())
        # catch
        #     close(request)
        #     rethrow()
        end
    end
catch err
    @error "Error in sync calls to FOO" exception = (err, catch_backtrace())
finally
    close(request)
end
OK, so this is tricky. The close exceptions occur because we do not wait for the spawned tasks after the for _ in 1:ntasks loop, so the outer finally closes request while they are still running. This itself can be fixed by @sync for _ in 1:ntasks. And it'd let you process the first calls as long as the number is a multiple of batchsize. But the program would deadlock since nobody is closing the request channel.
Presumably, you have a list of things processed by FOO. So, in principle, FOO can close the request channel inside the last call to batched_call, just before take!(promise).
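A minimal sketch of that idea, assuming the total number of calls is known up front. `batched_call_counted`, `process_batches`, and the `remaining` counter are hypothetical names, and the doubling function stands in for BAR; it is an illustration, not a tested design.

```julia
# Hypothetical variant of batched_call: `remaining` counts calls not yet submitted.
function batched_call_counted(request::Channel, remaining::Threads.Atomic{Int}, x)
    promise = Channel{Any}(1)
    put!(request, (x, promise))
    # atomic_sub! returns the old value; 1 means this was the last submission,
    # so every other caller has already finished its put!.
    if Threads.atomic_sub!(remaining, 1) == 1
        close(request)
    end
    return take!(promise)
end

# Simplified batch processor: flushes full batches, then the leftovers
# once `request` is closed and drained.
function process_batches(f, request::Channel; batchsize = 32)
    xs, promises = Any[], Channel{Any}[]
    flush!() = (foreach(put!, promises, f(xs)); empty!(xs); empty!(promises))
    for (x, p) in request
        push!(xs, x)
        push!(promises, p)
        length(xs) == batchsize && flush!()
    end
    isempty(xs) || flush!()
end

ncalls = 100
request = Channel{Tuple{Any,Channel{Any}}}(32)
remaining = Threads.Atomic{Int}(ncalls)
results = Vector{Any}(undef, ncalls)
@sync begin
    @async process_batches(xs -> 2 .* xs, request)  # stand-in for BAR
    for i in 1:ncalls
        Threads.@spawn results[i] = batched_call_counted(request, remaining, i)
    end
end
```

With 100 calls and a batch size of 32, the last 4 items go through the leftover branch after the channel is closed. This sketch leaves out error handling: if the batch function throws, the waiting callers would still hang.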
But it sounds like a tedious interface to use. I wonder what's the best way to handle this...
Since termination is rather tricky, I'd check if a dumb solution like just "processing batchsize items in each FOO" works. If FOO processes a much larger number of items than batchsize, I think it'd be much simpler.
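For comparison, a sketch of what that simpler approach could look like, assuming each FOO task owns a chunk of inputs and batches them locally. `FOO_chunk`, `BAR_batch`, and `nworkers` are illustrative names, and the doubling function is a stand-in for the real classifier.

```julia
# Hypothetical stand-in for the GPU classifier: operates on a whole batch at once.
BAR_batch(xs) = 2 .* xs

# Each task owns a chunk of inputs and batches locally,
# so no promises or shared request channel are needed.
function FOO_chunk(inputs::AbstractVector; batchsize = 32)
    outputs = Any[]
    for batch in Iterators.partition(inputs, batchsize)
        append!(outputs, BAR_batch(collect(batch)))
    end
    return outputs
end

inputs = collect(1:100)
nworkers = 4  # assumed number of FOO tasks
chunks = collect(Iterators.partition(inputs, cld(length(inputs), nworkers)))

tasks = map(chunks) do chunk
    Threads.@spawn FOO_chunk(chunk)
end
outputs = reduce(vcat, fetch.(tasks))
```

The trade-off is that batches are only full when each task has at least batchsize items of its own, which is why this fits best when every FOO handles many more items than batchsize.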
Do you think asyncmap would be a good fit for this? http://web.mit.edu/julia_v0.6.0/julia/share/doc/julia/html/en/stdlib/parallel.html#Base.asyncmap
Alternatively, I’m wondering if something simpler would work: a channel with max size 32 plus an atomic counter variable, so that whenever the counter hits 32, it reads 32 elements from the channel and processes them. That would then free up the channel.
If you have ncalls calls to FOO (e.g., 100 here), processing (ncalls ÷ 32) * 32 elements is already possible with a small tweak to your MWE:
function batch_processor(f, request::Channel{Tuple{X,Promise}}) where {X,Promise}
    batchsize = 32
    # list of images to classify
    # support 32 elements, but if less, don't give empty elements back
    xs = sizehint!(Vector{X}(undef, 0), batchsize)
    promises = sizehint!(Vector{Promise}(undef, 0), batchsize)
    # classify the images in batches
    nitems = 0
    try
        for (x, p) in request
            push!(xs, x)
            push!(promises, p)
            nitems += 1
            if nitems == batchsize
                @info xs
                foreach(put!, promises, f(xs))
                empty!(xs)
                empty!(promises)
                nitems = 0
            end
        end
        # classify any leftovers
        if !isempty(xs)
            foreach(put!, promises, f(xs))
        end
    catch
        foreach(close, promises)
        rethrow()
    end
end

function batched_call(request::Channel{Tuple{X,Channel{Y}}}, x::X) where {X,Y}
    promise = Channel{Y}(1)
    put!(request, (x, promise))
    return take!(promise)
end

# takes only `request`, matching the FOO(request) call below
function FOO(request::Channel)
    x = 5
    y = batched_call(request, x)
    return x == y
end

function BAR(x)
    return x
end

X = Y = Any
request = Channel{Tuple{X,Channel{Y}}}()  # (32) - maybe better if it's buffered; tuning needed
ntasks = 100
@sync try
    @async try
        batch_processor(BAR, request)
    catch err
        @error "Error in batch processor calls to BAR" exception = (err, catch_backtrace())
        rethrow()
    finally
        close(request)
    end
    @sync for _ in 1:ntasks
        Threads.@spawn try
            FOO(request)
        catch err
            @error "Error in batch processor calls to BAR" exception = (err, catch_backtrace())
            rethrow()
        # catch
        #     close(request)
        #     rethrow()
        end
    end
catch err
    @error "Error in sync calls to FOO" exception = (err, catch_backtrace())
    rethrow()
finally
    close(request)
end
But the problem is rather how to end the "last" tasks executing FOO.
I'm not sure if asyncmap is helpful here.
Maybe it's possible to write a generic solution if you can re-construct function FOO into the form
function FOO(request, input)
    x = FOO_preprocess(input)
    y = batched_call(request, x)
    z = FOO_postprocess(y, input)
    return z
end
Then a higher-order function can take FOO_preprocess, FOO_postprocess, BAR, and batchsize and call them appropriately.
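One possible shape for such a higher-order function, sketched under the assumption that the full list of inputs is available up front and is indexed from 1. `batched_pipeline` is a hypothetical name, and the three small functions at the bottom are toy stand-ins, not the real FOO or BAR.

```julia
# Hypothetical higher-order wrapper: it owns the input list, so it knows
# where the last batch ends and no shutdown protocol is needed.
function batched_pipeline(preprocess, bar, postprocess, inputs; batchsize = 32)
    outputs = Vector{Any}(undef, length(inputs))
    for idxs in Iterators.partition(1:length(inputs), batchsize)
        # Preprocess the batch in parallel, one task per input.
        xs = fetch.(map(i -> Threads.@spawn(preprocess(inputs[i])), idxs))
        # One batched call per chunk (the GPU call in the original question).
        ys = bar(xs)
        # Postprocess in parallel, pairing each result with its own input.
        zs = fetch.(map(k -> Threads.@spawn(postprocess(ys[k], inputs[idxs[k]])),
                        eachindex(idxs)))
        outputs[idxs] = zs
    end
    return outputs
end

# Toy stand-ins for illustration only.
FOO_preprocess(input) = input + 1
FOO_postprocess(y, input) = (input, y)
BAR(xs) = 2 .* xs

zs = batched_pipeline(FOO_preprocess, BAR, FOO_postprocess, 1:100)
```

Here results are matched to inputs by index rather than by promise, and the final partial batch (4 items for 100 inputs) is handled like any other chunk.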
Last updated: Nov 22 2024 at 04:41 UTC