Stream: helpdesk (published)

Topic: Task and Channel for Batch Scheduling


Yash Dalmia (May 15 2021 at 23:02):

Can anyone help me with a batch scheduling problem?

Description:
I have a function FOO that I want to run on multiple threads at once. While FOO is running, it will occasionally call another function BAR with an argument unique to each thread that FOO is running on. I want BAR to batch process whatever FOO passes to it (say 32 at a time), and return the values to each unique thread that is running FOO. Is there a way to do this using the Task Library or Channel Library?

If the description of FOO and BAR is too abstract: essentially, I have threads running that each want to classify images, but I want to batch the images before sending them to the GPU to be classified.

Takafumi Arakaki (tkf) (May 15 2021 at 23:16):

Does FOO have to wait for BAR? If so, what's different from a simple blocking call after accumulating 32 elements?

But, since you use GPU, maybe you want to pipeline several calls of BAR from each FOO?

Yash Dalmia (May 15 2021 at 23:20):

Takafumi Arakaki (tkf) said:

Does FOO have to wait for BAR? If so, what's different from a simple blocking call after accumulating 32 elements?

But, since you use GPU, maybe you want to pipeline several calls of BAR from each FOO?

Yes, FOO needs to wait for BAR before it can continue. I'm essentially just trying to figure out how to batch the calls while ensuring that each output from BAR goes back to the exact FOO that called it. For example, I thought one way to do this might be to put tuples of (thread_id, image) into a channel, classify 32 at a time, and then let each worker thread take the result for the image it put in. But I was interested to know what the best mechanism for this kind of idea is.

Takafumi Arakaki (tkf) (May 15 2021 at 23:58):

So, are 32 items coming from different tasks executing FOO? i.e., you want to combine 32 calls from different tasks to issue one GPU kernel call?

Takafumi Arakaki (tkf) (May 16 2021 at 00:01):

...and want to make sure the result of the call goes back to the correct FOO

Yash Dalmia (May 16 2021 at 00:06):

Yes. Essentially, let’s say:

Thread 1 calls G with arg 1
Thread 2 calls G with arg 2
...
Thread n calls G with arg n

I want G (the GPU) to operate on args 1, 2, ..., n in batches of 32 while making sure the result of G on arg 1 returns to thread 1, the result on arg 2 returns to thread 2, etc.

Takafumi Arakaki (tkf) (May 16 2021 at 00:22):

OK, I now understand the question. You essentially need a promise (future). The easiest way to do this in Julia is to use a single-element "one-shot" Channel (I think Go programmers use this all the time).

function batch_processor(f, request::Channel)
    batchsize = 32
    # It's better to use two arrays, but for now this is an array of (x, promise) 2-tuples:
    batch = sizehint!(Vector{eltype(request)}(undef, 0), batchsize)
    nitems = 0
    for r in request
        push!(batch, r)
        nitems += 1
        if nitems == batchsize
            xs = first.(batch)
            ys = f(xs)
            promises = last.(batch)
            foreach(put!, promises, ys)
            empty!(batch)
            nitems = 0
        end
    end
end

function batched_call(request::Channel, x)
    promise = Channel(1)
    # promise = Channel{TypeOfY}(1)  # if you know the result type of f
    put!(request, (x, promise))
    return take!(promise)
end

function FOO(request::Channel)
    ...
    y = batched_call(request, x)
    ...
end

request = Channel()
@sync try
    # request = Channel(32)  # maybe better if it's buffered; tuning needed
    @async try
        batch_processor(BAR, request)
    finally
        close(request)
    end
    for _ in 1:ntasks
        Threads.@spawn try
            FOO(request)
        finally
            close(request)
        end
    end
finally
    close(request)
end
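
Stripped of batching, the promise pattern above reduces to a request/reply loop. The following is a minimal sketch, not part of the original code: the names `requests`, `server` and `ask` are hypothetical, and squaring stands in for BAR. Each request carries its own one-shot reply channel, so the server's answer can only reach the task that asked.

```julia
# Each request carries its own one-shot reply channel (the "promise").
requests = Channel{Tuple{Int,Channel{Int}}}(8)

# Server task: answers every request on the reply channel that came with it.
server = @async for (x, reply) in requests
    put!(reply, x^2)  # stand-in for BAR
end

# Client helper: send x, then block until the answer to *this* request arrives.
function ask(requests, x)
    reply = Channel{Int}(1)  # buffer of 1, so the server's put! never blocks
    put!(requests, (x, reply))
    return take!(reply)
end

# Ten concurrent clients; each one gets back the square of its own input.
answers = fetch.([Threads.@spawn ask(requests, i) for i in 1:10])

close(requests)  # lets the server's for loop finish
wait(server)
```

Batching only changes the server side: it collects several `(x, reply)` pairs before computing, then answers each reply channel in turn.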

Takafumi Arakaki (tkf) (May 16 2021 at 00:26):

I haven't done any extensive benchmarks, but here is a more direct implementation of a promise than Channel(1): https://github.com/JuliaFolds/FoldsThreads.jl/blob/08329298b73c054e1552f7f4a358e2cdcaf89027/src/utils.jl#L42

Takafumi Arakaki (tkf) (May 16 2021 at 00:26):

but, as you can see in the above code, the tedious part is more about how to use promises
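
For comparison with Channel(1), here is a hypothetical single-assignment promise built directly on `Threads.Condition`. It is a sketch written for illustration, not the FoldsThreads.jl implementation linked above, and its performance relative to Channel(1) has not been measured.

```julia
# A hypothetical single-assignment promise built on Threads.Condition.
mutable struct Promise{T}
    cond::Threads.Condition
    value::Union{Nothing,Some{T}}  # `nothing` until fulfilled
end
Promise{T}() where {T} = Promise{T}(Threads.Condition(), nothing)

# Fulfil the promise once and wake every waiting task.
function Base.put!(p::Promise{T}, v) where {T}
    lock(p.cond)
    try
        p.value === nothing || error("promise already fulfilled")
        p.value = Some{T}(v)
        notify(p.cond)
    finally
        unlock(p.cond)
    end
    return p
end

# Block until the promise is fulfilled, then return its value (can be called repeatedly).
function Base.fetch(p::Promise)
    lock(p.cond)
    try
        while p.value === nothing
            wait(p.cond)
        end
        return something(p.value)
    finally
        unlock(p.cond)
    end
end

p = Promise{Int}()
t = Threads.@spawn fetch(p)  # waits in another task
put!(p, 7)
```

Unlike `take!` on a Channel(1), `fetch` here does not consume the value, so several tasks can wait on the same promise.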

Takafumi Arakaki (tkf) (May 16 2021 at 00:27):

(btw, the above code is totally untested. there can be typos etc.)

Takafumi Arakaki (tkf) (May 16 2021 at 00:30):

An interesting point is that it can be a useful pattern to use an ntasks larger than nthreads() or the number of CPUs. The scheduler suspends a task waiting on take!(promise) and switches to another task, so all the tasks share a worker pool of OS threads.
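
A quick way to see that a task blocked in `take!` does not pin an OS thread is to spawn far more tasks than threads and park all of them at once. A minimal sketch; the factor of 100 is arbitrary:

```julia
# Far more tasks than OS threads; each one parks on its own one-shot channel.
ntasks = 100 * Threads.nthreads()
promises = [Channel{Int}(1) for _ in 1:ntasks]
tasks = [Threads.@spawn(take!(promises[i]) + 1) for i in 1:ntasks]

# Tasks suspended in take! do not hold a thread, so fulfilling the promises
# from here (while the others are still parked) cannot deadlock.
for (i, p) in enumerate(promises)
    put!(p, i)
end

results = fetch.(tasks)
```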

Takafumi Arakaki (tkf) (May 16 2021 at 00:34):

(I just noticed that batch_processor might need try-finally around the for loop to close all pending promises, to ensure proper shutdown.)
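
Closing the pending promises is a safe shutdown step because of how Channel behaves in both cases: a task blocked on an empty, closed channel gets an `InvalidStateException` instead of hanging, and a value already `put!` into a buffered channel can still be taken after `close`. A minimal sketch of both cases:

```julia
# Case 1: a promise that was never fulfilled. Closing it wakes the waiter with an error.
pending = Channel{Int}(1)
waiter = Threads.@spawn try
    take!(pending)
catch err
    err  # return the exception instead of blocking forever
end
close(pending)  # what a `finally` in batch_processor would do for unfilled promises
outcome = fetch(waiter)

# Case 2: a promise fulfilled before the close. The value is still delivered.
fulfilled = Channel{Int}(1)
put!(fulfilled, 42)
close(fulfilled)
value = take!(fulfilled)
```

So a `finally` that closes every promise in the current batch does not lose results that were already delivered.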

Yash Dalmia (May 16 2021 at 00:42):

Thank you so much! This was incredibly helpful. I will look into Go-style programming.

Takafumi Arakaki (tkf) (May 16 2021 at 00:48):

You are welcome :) yeah, I think Go style is a good inspiration for this kind of thing

Takafumi Arakaki (tkf) (May 16 2021 at 00:50):

I was just tweaking my post above since I keep noticing some gotchas :)

function batch_processor(f, request::Channel{Tuple{X,Promise}}) where {X,Promise}
    batchsize = 32
    xs = sizehint!(Vector{X}(undef, 0), batchsize)
    promises = sizehint!(Vector{Promise}(undef, 0), batchsize)
    nitems = 0
    try
        for (x, p) in request
            push!(xs, x)
            push!(promises, p)
            nitems += 1
            if nitems == batchsize
                foreach(put!, promises, f(xs))
                empty!(xs)
                empty!(promises)
                nitems = 0
            end
        end
        if !isempty(xs)  # flush the final partial batch
            foreach(put!, promises, f(xs))
        end
    catch
        foreach(close, promises)
        rethrow()
    end
end

function batched_call(request::Channel{Tuple{X,Channel{Y}}}, x::X) where {X,Y}
    promise = Channel{Y}(1)
    put!(request, (x, promise))
    return take!(promise)
end

function FOO(request::Channel)
    ...
    y = batched_call(request, x)
    ...
end

X = Y = Any  # or more concrete, if known
request = Channel{Tuple{X,Channel{Y}}}()
# request = Channel{Tuple{X,Channel{Y}}}(32)  # maybe better if it's buffered; tuning needed
@sync try
    @async try
        batch_processor(BAR, request)
    finally
        close(request)
    end
    for _ in 1:ntasks
        Threads.@spawn try
            FOO(request)
        catch
            close(request)
            rethrow()
        end
    end
finally
    close(request)
end

Yash Dalmia (May 16 2021 at 02:28):

It seems really close to working, but it throws the following error:

TaskFailedException
    nested task error: InvalidStateException("Channel is closed.", :closed)
    Stacktrace:
     [1] check_channel_state
       @ ./channels.jl:170 [inlined]
     [2] put!
       @ ./channels.jl:314 [inlined]
     [3] batched_call(request::Channel{Tuple{Any, Channel{Any}}}, x::Int64)
       @ Main ./In[22]:35
     [4] FOO(request::Channel{Tuple{Any, Channel{Any}}})
       @ Main ./In[22]:41
     [5] (::var"#38#40")()
       @ Main ./threadingconstructs.jl:169
...and 99 more exceptions.
Stacktrace:
 [1] sync_end(c::Channel{Any})
   @ Base ./task.jl:369
 [2] top-level scope
   @ task.jl:388
 [3] eval
   @ ./boot.jl:360 [inlined]
 [4] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
   @ Base ./loading.jl:1094

This is the code I used

function batch_processor(f, request::Channel{Tuple{X,Promise}}) where {X,Promise}
    batchsize = 32

    # list of images to classify
    # support 32 elements, but if less, don't give empty elements back
    xs = sizehint!(Vector{X}(undef, 0), batchsize)
    promises = sizehint!(Vector{Promise}(undef, 0), batchsize)

    # classify the images in batches
    nitems = 0
    try
        for (x, p) in request
            push!(xs, x)
            push!(promises, p)
            nitems += 1
            if nitems == batchsize
                foreach(put!, promises, f(xs))
                empty!(xs)
                empty!(promises)
                nitems = 0
            end
        end
        # classify any leftovers
        if !isempty(xs)
            foreach(put!, promises, f(xs))
        end
    catch
        foreach(close, promises)
        rethrow()
    end
end

function batched_call(request::Channel{Tuple{X,Channel{Y}}}, x::X) where {X,Y}
    promise = Channel{Y}(1)
    put!(request, (x, promise))
    return take!(promise)
end

function FOO(request::Channel)
    x = 5
    y = batched_call(request, x)
    return x == y
end

function BAR(x)
    return x
end

X = Y = Any
request = Channel{Tuple{X,Channel{Y}}}()   # (32) - maybe better if it's buffered; tuning needed
ntasks = 100

@sync try
    @async try
        batch_processor(BAR, request)
    finally
        close(request)
    end
    for _ in 1:ntasks
        Threads.@spawn try
            FOO(request)
        catch
            close(request)
            rethrow()
        end
    end
finally
    close(request)
end

Takafumi Arakaki (tkf) (May 16 2021 at 02:31):

can you put three backticks before and after the code?

Yash Dalmia (May 16 2021 at 02:31):

Yes! Sorry, my bad! I was trying to figure out how to format it.

Takafumi Arakaki (tkf) (May 16 2021 at 02:39):

It's a bit hard to diagnose since InvalidStateException("Channel is closed.", :closed) is not the initial cause of the error. It's not elegant, but I often put

catch err
    @error "Error at XXX" exception = (err, catch_backtrace())

before the finally block in each @spawn and @async, to figure out the initial cause.
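
Put together, the pattern looks roughly like this. A minimal sketch in which `failing_work` is a hypothetical stand-in for FOO or batch_processor; the `rethrow()` keeps the failure visible to `@sync`/`wait` after it has been logged:

```julia
# Hypothetical worker standing in for FOO or batch_processor.
failing_work() = error("root cause")

t = Threads.@spawn try
    failing_work()
catch err
    # Log the first exception with its backtrace before any `finally` runs
    # (e.g., one that closes a channel and triggers secondary errors elsewhere).
    @error "Error in worker" exception = (err, catch_backtrace())
    rethrow()  # still propagate, so @sync / wait can see the failure
finally
    nothing  # cleanup such as close(request) would go here
end

failed = try
    wait(t)
    false
catch e
    e isa TaskFailedException
end
```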

Yash Dalmia (May 16 2021 at 02:46):

This is the error message now:

 Error: Error in batch processor calls to BAR
   exception = (InvalidStateException("Channel is closed.", :closed), Union{Ptr{Nothing}, Base.InterpreterIP}[Ptr{Nothing} @0x000000016adffc06, Ptr{Nothing} @0x000000016adffce7, Ptr{Nothing} @0x000000010a9bf015, Ptr{Nothing} @0x000000016adff4e8, Ptr{Nothing} @0x000000016adff7ac, Ptr{Nothing} @0x000000010a9bf015, Ptr{Nothing} @0x000000010a9da25d])
 @ Main In[25]:66

Code is below:

function batch_processor(f, request::Channel{Tuple{X,Promise}}) where {X,Promise}
    batchsize = 32

    # list of images to classify
    # support 32 elements, but if less, don't give empty elements back
    xs = sizehint!(Vector{X}(undef, 0), batchsize)
    promises = sizehint!(Vector{Promise}(undef, 0), batchsize)

    # classify the images in batches
    nitems = 0
    try
        for (x, p) in request
            push!(xs, x)
            push!(promises, p)
            nitems += 1
            if nitems == batchsize
                foreach(put!, promises, f(xs))
                empty!(xs)
                empty!(promises)
                nitems = 0
            end
        end
        # classify any leftovers
        if !isempty(xs)
            foreach(put!, promises, f(xs))
        end
    catch
        foreach(close, promises)
        rethrow()
    end
end

function batched_call(request::Channel{Tuple{X,Channel{Y}}}, x::X) where {X,Y}
    promise = Channel{Y}(1)
    put!(request, (x, promise))
    return take!(promise)
end

function FOO(request::Channel)
    x = 5
    y = batched_call(request, x)
    return x == y
end

function BAR(x)
    return x
end

X = Y = Any
request = Channel{Tuple{X,Channel{Y}}}()   # (32) - maybe better if it's buffered; tuning needed
ntasks = 100

@sync try
    @async try
        batch_processor(BAR, request)
    catch err
        @error "Error in batch processor calls to BAR" exception = (err, catch_backtrace())
    finally
        close(request)
    end
    for _ in 1:ntasks
        Threads.@spawn try
            FOO(request)

        catch err
            @error "Error in batch processor calls to BAR" exception = (err, catch_backtrace())

#         catch
#             close(request)
#             rethrow()
        end
    end

catch err
    @error "Error in sync calls to FOO" exception = (err, catch_backtrace())

finally
    close(request)
end

Takafumi Arakaki (tkf) (May 16 2021 at 03:48):

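Ok, so this is tricky. The "Channel is closed" exceptions happen because we are not waiting for the tasks spawned in for _ in 1:ntasks: the outer finally block runs (and closes request) as soon as those tasks are spawned, before @sync waits for them. This itself can be fixed with @sync for _ in 1:ntasks, which would let you process the first (ntasks ÷ batchsize) * batchsize calls. But then the program would deadlock: the remaining calls wait forever on take!(promise), because nobody closes the request channel, so batch_processor never flushes the leftover batch.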
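
The ordering problem can be reproduced without any channels. In this minimal sketch, `events` is just a log: the `finally` of `@sync try ... finally ... end` runs as soon as the try body has spawned its tasks, and only afterwards does `@sync` wait. An inner `@sync` moves the wait inside the try body.

```julia
events = String[]

# Without an inner @sync: `finally` runs before the spawned task has finished.
@sync try
    @async (sleep(0.1); push!(events, "task finished"))
finally
    push!(events, "finally ran")
end
before_fix = copy(events)

# With an inner @sync: the try body waits for the task, so `finally` runs last.
empty!(events)
@sync try
    @sync @async (sleep(0.1); push!(events, "task finished"))
finally
    push!(events, "finally ran")
end
after_fix = copy(events)
```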

Presumably, you have a list of things processed by FOO. So, in principle, FOO can close the request channel inside the last call to batched_call just before take!(promise).

But that sounds like a tedious interface to use. I wonder what the best way to handle this is...
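
The idea of letting FOO close the request channel inside the last call to batched_call can be sketched with an atomic counter of calls that have not been submitted yet. This is a minimal sketch assuming the total number of calls (here ntasks) is known up front; the `remaining` counter and the three-argument `batched_call` are hypothetical, and error handling is omitted, so a BAR that throws would still leave callers blocked.

```julia
# Batches (x, promise) pairs; flushes the partial batch once `request` is closed and drained.
function batch_processor(f, request::Channel{Tuple{X,Channel{Y}}}, batchsize) where {X,Y}
    xs = X[]
    promises = Channel{Y}[]
    for (x, p) in request
        push!(xs, x)
        push!(promises, p)
        if length(xs) == batchsize
            foreach(put!, promises, f(xs))
            empty!(xs)
            empty!(promises)
        end
    end
    isempty(xs) || foreach(put!, promises, f(xs))  # leftovers
end

# Hypothetical variant: the call that submits the last expected item closes `request`.
function batched_call(request::Channel{Tuple{X,Channel{Y}}}, remaining, x::X) where {X,Y}
    promise = Channel{Y}(1)
    put!(request, (x, promise))
    # atomic_sub! returns the old value, so 1 means every call has now been submitted.
    if Threads.atomic_sub!(remaining, 1) == 1
        close(request)
    end
    return take!(promise)
end

batch_sizes = Int[]                                  # only touched by the processor task
BAR(xs) = (push!(batch_sizes, length(xs)); 2 .* xs)  # stand-in for the GPU call

ntasks = 100
request = Channel{Tuple{Int,Channel{Int}}}(32)
remaining = Threads.Atomic{Int}(ntasks)
results = Vector{Int}(undef, ntasks)

@sync begin
    @async batch_processor(BAR, request, 32)
    for i in 1:ntasks
        Threads.@spawn results[i] = batched_call(request, remaining, i)
    end
end
```

Each task decrements the counter only after its own `put!` has returned, so when the old value is 1 every item is already in the channel and closing it cannot reject a pending request. With 100 calls and a batch size of 32, BAR should see three full batches followed by a leftover batch of 4.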

Takafumi Arakaki (tkf) (May 16 2021 at 04:36):

Since termination is rather tricky, I'd check if a dumb solution like just "processing batchsize items in each FOO" works. If FOO processes a much larger number of items than batchsize, I think it'd be much simpler.
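
A minimal sketch of that simpler layout, where each task handles a whole chunk of batchsize items and calls BAR once per chunk. `preprocess` and this `BAR` are hypothetical stand-ins for FOO's per-item work and the GPU call:

```julia
preprocess(x) = x + 1  # hypothetical per-item work done by FOO
BAR(xs) = 2 .* xs      # hypothetical batched call, e.g. the GPU classifier

inputs = collect(1:100)
batchsize = 32

# One task per chunk: each task plays FOO for `batchsize` items and calls BAR once,
# so no cross-task batching, promises, or shutdown protocol is needed.
tasks = [Threads.@spawn(BAR([preprocess(x) for x in chunk]))
         for chunk in Iterators.partition(inputs, batchsize)]
results = reduce(vcat, fetch.(tasks))
```

The trade-off is that the last chunk may be smaller than batchsize, and a task cannot overlap its own preprocessing with other tasks' GPU calls the way the promise-based version can.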

Yash Dalmia (May 16 2021 at 05:00):

Do you think asyncmap would be a good fit for this? http://web.mit.edu/julia_v0.6.0/julia/share/doc/julia/html/en/stdlib/parallel.html#Base.asyncmap

Alternatively, I’m wondering whether something as simple as this would work: a channel with a max size of 32 plus an atomic counter, so that whenever the counter hits 32, a worker reads the 32 elements from the channel and processes them, which frees up the channel again.

Takafumi Arakaki (tkf) (May 16 2021 at 05:11):

If you have ncalls calls to FOO (e.g., 100 here), processing (ncalls ÷ 32) * 32 elements is already possible with a small tweak to your MWE (note the @sync before for _ in 1:ntasks):

function batch_processor(f, request::Channel{Tuple{X,Promise}}) where {X,Promise}
    batchsize = 32

    # list of images to classify
    # support 32 elements, but if less, don't give empty elements back
    xs = sizehint!(Vector{X}(undef, 0), batchsize)
    promises = sizehint!(Vector{Promise}(undef, 0), batchsize)

    # classify the images in batches
    nitems = 0
    try
        for (x, p) in request
            push!(xs, x)
            push!(promises, p)
            nitems += 1
            if nitems == batchsize
                @info xs
                foreach(put!, promises, f(xs))
                empty!(xs)
                empty!(promises)
                nitems = 0
            end
        end
        # classify any leftovers
        if !isempty(xs)
            foreach(put!, promises, f(xs))
        end
    catch
        foreach(close, promises)
        rethrow()
    end
end

function batched_call(request::Channel{Tuple{X,Channel{Y}}}, x::X) where {X,Y}
    promise = Channel{Y}(1)
    put!(request, (x, promise))
    return take!(promise)
end

function FOO(request::Channel)  # called as FOO(request) below
    x = 5
    y = batched_call(request, x)
    return x == y
end

function BAR(x)
    return x
end

X = Y = Any
request = Channel{Tuple{X,Channel{Y}}}()   # (32) - maybe better if it's buffered; tuning needed
ntasks = 100


@sync try
    @async try
        batch_processor(BAR, request)
    catch err
        @error "Error in batch processor calls to BAR" exception = (err, catch_backtrace())
        rethrow()
    finally
        close(request)
    end
    @sync for _ in 1:ntasks
        Threads.@spawn try
            FOO(request)
        catch err
            @error "Error in batch processor calls to BAR" exception = (err, catch_backtrace())
            rethrow()

#         catch
#             close(request)
#             rethrow()
        end
    end

catch err
    @error "Error in sync calls to FOO" exception = (err, catch_backtrace())
    rethrow()

finally
    close(request)
end

Takafumi Arakaki (tkf) (May 16 2021 at 05:13):

but the problem is rather how to end the "last" tasks executing FOO

Takafumi Arakaki (tkf) (May 16 2021 at 05:16):

I'm not sure if asyncmap is helpful here.

Takafumi Arakaki (tkf) (May 16 2021 at 05:18):

Maybe it's possible to write a generic solution if you can restructure function FOO into the form

function FOO(request, input)
    x = FOO_preprocess(input)
    y = batched_call(request, x)
    z = FOO_postprocess(y, input)
    return z
end

Then a higher-order function can take FOO_preprocess, FOO_postprocess, BAR, and batchsize and call them appropriately.
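
A minimal sketch of such a higher-order function, assuming all inputs are known up front so the number of calls can be counted. The name `batched_foreach` and its `batchsize` keyword are hypothetical. It closes the request channel from the task that submits the last item (so the final partial batch is flushed) and counts a call as submitted even if `pre` throws, but it does not handle errors thrown by `bar`.

```julia
# Batches (x, promise) pairs and flushes the final partial batch once `request` is closed.
function batch_processor(f, request::Channel, batchsize)
    xs = Any[]
    promises = Channel{Any}[]
    for (x, p) in request
        push!(xs, x)
        push!(promises, p)
        if length(xs) == batchsize
            foreach(put!, promises, f(xs))
            empty!(xs)
            empty!(promises)
        end
    end
    isempty(xs) || foreach(put!, promises, f(xs))
end

# Hypothetical driver: runs pre -> batched bar -> post for every input, one task per input.
function batched_foreach(pre, bar, post, inputs; batchsize = 32)
    n = length(inputs)
    results = Vector{Any}(undef, n)
    n == 0 && return results
    request = Channel{Tuple{Any,Channel{Any}}}(batchsize)
    remaining = Threads.Atomic{Int}(n)  # calls not yet submitted
    @sync begin
        @async batch_processor(bar, request, batchsize)
        for (i, input) in enumerate(inputs)
            Threads.@spawn begin
                promise = Channel{Any}(1)
                try
                    put!(request, (pre(input), promise))
                finally
                    # Count this call even if `pre` threw; the last one closes `request`.
                    Threads.atomic_sub!(remaining, 1) == 1 && close(request)
                end
                results[i] = post(take!(promise), input)
            end
        end
    end
    return results
end

out = batched_foreach(x -> x + 1, xs -> 2 .* xs, (y, input) -> (input, y), 1:100)
```

Here `pre`, `bar` and `post` play the roles of FOO_preprocess, BAR and FOO_postprocess; results are stored by index, so each output lines up with the input that produced it.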


Last updated: Dec 28 2024 at 04:38 UTC