Any idea why this function is ~100x faster than Python for small arrays but the same for larger arrays?
using Tullio, Statistics  # needed for @tullio and mean

function hd_loss(ŷ, y, ŷ_dtm, y_dtm)
    Δ = (ŷ .- y) .^ 2
    dtm = (ŷ_dtm .^ 2) + (y_dtm .^ 2)
    @tullio M[x, y, z] := Δ[x, y, z] * dtm[x, y, z]
    hd_loss = mean(M)
end
When I time it for small arrays, size = (4, 4, 2), I get results ~100x faster than Python:
@btime hd_loss(ŷ, y, ŷ_dtm, y_dtm)
# 495.728 ns (6 allocations: 1.66 KiB)
Compared to Python
%timeit hd_loss(y, y_hat, gt_dtm, seg_dtm)
# (36.6 µs ± 260 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each))
When I time it for larger arrays, size = (96, 96, 96), I get results on par with Python:
@btime hd_loss(Ŷ, Y, Ŷ_dtm, Y_dtm)
# 2.395 ms (55 allocations: 33.75 MiB)
Compared to Python
%timeit hd_loss(large_y, large_y_hat, large_gt_dtm, large_seg_dtm)
# 2.99 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Here is the Python function
import torch  # numpy arrays come in from the caller; torch does the einsum

def hd_loss(seg_soft, gt, seg_dtm, gt_dtm):
    """
    compute hausdorff distance loss for binary segmentation
    """
    delta_s = (seg_soft - gt) ** 2
    s_dtm = seg_dtm ** 2
    g_dtm = gt_dtm ** 2
    dtm = s_dtm + g_dtm
    multipled = torch.einsum("xyz, xyz->xyz", torch.from_numpy(delta_s), torch.from_numpy(dtm))
    hd_loss = multipled.mean()
    return hd_loss
Python incurs a fixed per-call penalty because the code has to be interpreted first. By the time you get to large arrays, though, that initial overhead is dominated by the calculation itself, and in both languages most of that computation is actually being done in native code. Back of the envelope: the ~36 µs gap at the small size is only about 1% of the ~3 ms either language spends on the 96³ arrays.
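One way to see this from the Julia side (a rough sketch of my own, not a measurement from this thread; it assumes hd_loss and BenchmarkTools are already loaded): time the same call at a few sizes and watch the fixed per-call cost vanish once the N³ arithmetic dominates.

using BenchmarkTools

for N in (4, 16, 64, 96)
    a, b, c, d = rand(N, N, N), rand(N, N, N), rand(N, N, N), rand(N, N, N)
    t = @belapsed hd_loss($a, $b, $c, $d)   # minimum time in seconds
    println("N = ", N, ": ", round(t * 1e6; digits = 1), " µs")
end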
Ahh, okay thanks for clearing that up!
If you had tried to implement that einsum in Python itself, it would be painfully slow. Instead, though, you call torch's einsum, which is implemented in C++.
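In Julia, by contrast, you could write that elementwise loop yourself and still get native-code speed. A minimal sketch (my own variant, not code from this thread; the name hd_loss_loop is made up):

function hd_loss_loop(ŷ, y, ŷ_dtm, y_dtm)
    tot = zero(eltype(ŷ))
    @inbounds @simd for i in eachindex(ŷ, y, ŷ_dtm, y_dtm)
        # same elementwise product as the einsum, accumulated on the fly
        tot += (ŷ[i] - y[i])^2 * (ŷ_dtm[i]^2 + y_dtm[i]^2)
    end
    return tot / length(y)
end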
Also, when you benchmark with global variables you should interpolate them with $, like so:
@btime hd_loss($Ŷ, $Y, $Ŷ_dtm, $Y_dtm)
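The point of the $ (a sketch, assuming BenchmarkTools is loaded and the arrays are non-constant globals):

using BenchmarkTools
@btime hd_loss(Ŷ, Y, Ŷ_dtm, Y_dtm)      # globals are untyped, so dynamic dispatch gets mixed into the timing
@btime hd_loss($Ŷ, $Y, $Ŷ_dtm, $Y_dtm)  # $ splices in the values, so only the function itself is measured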
Hm, if you use views where applicable and in-place dotted functions, Julia should still be faster than Python.
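For example, with a preallocated buffer and a fully dotted assignment, the broadcast itself doesn't allocate anything per call (a sketch of my own; buf and hd_loss_inplace! are made-up names):

using Statistics

function hd_loss_inplace!(buf, ŷ, y, ŷ_dtm, y_dtm)
    @. buf = (ŷ - y)^2 * (ŷ_dtm^2 + y_dtm^2)  # fused broadcast, written into buf
    return mean(buf)
end

buf = similar(ŷ)  # allocate once, reuse across calls
hd_loss_inplace!(buf, ŷ, y, ŷ_dtm, y_dtm)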
I think you can avoid allocating 3 temporary arrays (Δ, dtm and M), like so:
function hd_loss_2(ŷ, y, ŷ_dtm, y_dtm)
    M = @. (ŷ - y)^2 * (ŷ_dtm^2 + y_dtm^2)  # allocates just one array
    mean(M)
end
function hd_loss_3(ŷ, y, ŷ_dtm, y_dtm)
    @tullio tot := (ŷ[i,j,k] - y[i,j,k])^2 * (ŷ_dtm[i,j,k]^2 + y_dtm[i,j,k]^2)
    return tot / length(y)
end
In fact, dtm = (ŷ_dtm .^ 2) + (y_dtm .^ 2) allocates 3 arrays (the two squared arrays plus their sum), so the original hd_loss allocates 5 in total. With dtm = (ŷ_dtm .^ 2) .+ (y_dtm .^ 2) the operations would fuse into a single loop with a single output array. (I missed this on first reading, and my first hd_loss_2 had a similar mistake, now corrected; @. is a good way not to miss any dots.)
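You can see the fusion directly by counting allocations (a quick check of my own, not from this thread; a and b are just placeholder arrays):

using BenchmarkTools
a, b = rand(96, 96, 96), rand(96, 96, 96)

@btime ($a .^ 2) + ($b .^ 2)    # un-fused: allocates a.^2, b.^2, and their sum, 3 arrays in total
@btime ($a .^ 2) .+ ($b .^ 2)   # fused: one loop, one output array
@btime @. $a ^ 2 + $b ^ 2       # @. puts in all the dots for you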
Times (on a slow computer):
julia> N = 4; ŷ, y, ŷ_dtm, y_dtm = rand(N,N,N), rand(N,N,N), rand(N,N,N), rand(N,N,N);
julia> @btime hd_loss($ŷ, $y, $ŷ_dtm, $y_dtm)
1.344 μs (5 allocations: 3.05 KiB)
0.10481320873450592
julia> @btime hd_loss_2($ŷ, $y, $ŷ_dtm, $y_dtm)
439.258 ns (1 allocation: 624 bytes)
0.10481320873450592
julia> @btime hd_loss_3($ŷ, $y, $ŷ_dtm, $y_dtm)
156.461 ns (1 allocation: 16 bytes)
0.10481320873450593
julia> N = 96; ŷ, y, ŷ_dtm, y_dtm = rand(N,N,N), rand(N,N,N), rand(N,N,N), rand(N,N,N);
julia> @btime hd_loss($ŷ, $y, $ŷ_dtm, $y_dtm)
13.657 ms (26 allocations: 33.75 MiB)
0.11076876938711891
julia> @btime hd_loss_2($ŷ, $y, $ŷ_dtm, $y_dtm)
4.257 ms (2 allocations: 6.75 MiB)
0.11076876938711891
julia> @btime hd_loss_3($ŷ, $y, $ŷ_dtm, $y_dtm)
1.854 ms (41 allocations: 2.70 KiB)
0.11076876938711909
Thank you! That is much faster and easier to read