How can we make KA fast on CPUs? #509

Open
avik-pal opened this issue Aug 20, 2024 · 10 comments
@avik-pal
Contributor

See LuxDL/LuxLib.jl#136 for some background context. The main motivation for me is to avoid code duplication between the CPU and GPU versions. However, if you look at the benchmark comment on that PR (for batchnorm and groupnorm), you see a 10x-40x slowdown of KA relative to the equivalent optimized loop version (note that the loop version simply uses @simd or @simd ivdep, nothing like LoopVectorization).

I think there are a couple of reasons for the slowdown:

  1. @simd annotations are missing (which causes slowdown even in the loop version if I remove the annotations)
  2. threading has overhead for some of the smaller problems

Potential solutions:

  1. Allow users to control threading. [FR] Add nthreads argument to CPU backend #507. For smaller problems, I want to opt out of threading manually.
  2. @simd annotations (Make CPU loops simd & ivdep #436 seems to do this; not sure what the status of that is)
  3. Alternate threading: KA is being used inside "core" operations, so it is unlikely (if not impossible) that they will call other operations that themselves use threading. Hence, having the option to use "cheaper threads" (Polyester.jl) would be a great addition; see the sketch below.
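
For (3), a rough sketch of what the "cheaper threads" option could look like, using Polyester's @batch on the batchnorm routine that appears later in this thread (an illustration, not a proposed KA API):

using Polyester

# Sketch: Polyester's @batch spawns lightweight, non-composable tasks, which is
# acceptable inside "core" operations that never call back into threaded code.
function batchnorm_polyester!(y, scale, bias, x, μ, σ²)
    @batch for j in axes(x, 2)
        @simd ivdep for i in axes(x, 1)
            @inbounds y[i, j] = tanh(scale[i] * (x[i, j] - μ[i]) / sqrt(σ²[i] + 1e-5) + bias[i])
        end
    end
    return y
end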
@vchuravy
Member

Does #436 speed things up for you? I haven't merged it since I haven't seen an impact on benchmarks.

@avik-pal
Contributor Author

avik-pal commented Aug 22, 2024

Let me run the benchmarks with that branch and check.

EDIT: doesn't seem to help

@vchuravy
Member

So the question would be where the overheads are coming from; maybe you can run a profile and compare where the time is being spent. You could also use static scheduling.

Could you isolate a particular case where you are seeing these overheads?
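
A sketch of both suggestions, reusing the kernel and arguments from the MWE later in this thread as placeholders (KA's CPU backend accepts a static flag to switch to static scheduling):

using KernelAbstractions, Profile

backend = CPU(; static=true)   # static scheduling instead of the dynamic default
kernel! = batchnorm_kernel_act!(backend)

Profile.clear()
@profile for _ in 1:1_000      # repeat so the profiler collects enough samples
    kernel!(y, tanh, scale, bias, x, μ, σ²; ndrange=size(x))
    KernelAbstractions.synchronize(backend)
end
Profile.print()                # then compare against a profile of the loop version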

@maleadt
Member

maleadt commented Sep 18, 2024

An MWE would be very useful, as the upcoming POCL CPU back-end may be interesting for performance, but hasn't been benchmarked.

@avik-pal
Contributor Author

Oops, this fell off my radar; I will create a self-contained example this week.

@avik-pal
Contributor Author

(Independent of this) once POCL is ready and I can access a CI server with it, I can trigger the benchmark suite for the NN primitives in LuxLib to generate results at https://luxdl.github.io/LuxLib.jl/benchmarks/

@maleadt
Member

maleadt commented Sep 20, 2024

Good to know. OpenCL.jl's KA.jl integration is actually functional already (see JuliaGPU/OpenCL.jl#233), but not tagged yet. If you're feeling adventurous, you can test it out already: just add the master branch of OpenCL.jl and its SPIRVIntrinsics.jl subpackage from lib/intrinsics, do using OpenCL, pocl_jll, and you'll have a functional OpenCLBackend to use with KA.jl (in combination with a CLArray type for the usual stuff). Best to do so on an x86_64 machine, as our aarch64 build of pocl is known to have performance issues.
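
A sketch of that setup, assuming the repository layout described above:

using Pkg
# Not tagged yet at the time of writing, so track master and the subpackage directly.
Pkg.add(url="https://github.com/JuliaGPU/OpenCL.jl")
Pkg.add(url="https://github.com/JuliaGPU/OpenCL.jl", subdir="lib/intrinsics")  # SPIRVIntrinsics.jl
Pkg.add("pocl_jll")

using OpenCL, pocl_jll
backend = OpenCLBackend()              # usable anywhere KA.jl expects a backend
x_cl = CLArray(rand(Float32, 32, 32))  # device array for the usual stuff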

I expect issues with real applications once you test this out, though, as the KA.jl and GPUArrays.jl test suites don't cover e.g. all array constructors or expected interfaces. Such feedback would be valuable!

@avik-pal
Contributor Author

A toned-down version of the 2D batchnorm routine from LuxLib:

using Statistics, LinearAlgebra
using KernelAbstractions
using OpenCL, pocl_jll
using BenchmarkTools
using Test

function batchnorm_looped(x::AbstractMatrix, scale::AbstractVector, bias::AbstractVector, μ::AbstractVector, σ²::AbstractVector)
    y = similar(x)
    @inbounds for j in axes(x, 2)
        @simd ivdep for i in axes(x, 1)
            y[i, j] = tanh(scale[i] * (x[i, j] - μ[i]) / sqrt(σ²[i] + 1e-5) + bias[i])
        end
    end
    return y
end

@kernel function batchnorm_kernel_act!(y, @Const(act), @Const(scale), @Const(bias), @Const(x), @Const(μ), @Const(σ²))
    i, j = @index(Global, NTuple)
    y[i, j] = act(scale[i] * (x[i, j] - μ[i]) / sqrt(σ²[i] + 1e-5) + bias[i])
end

function batchnorm_ka(x::AbstractMatrix, scale::AbstractVector, bias::AbstractVector, μ::AbstractVector, σ²::AbstractVector)
    y = similar(x)
    backend = KernelAbstractions.get_backend(x)
    kernel! = batchnorm_kernel_act!(backend)
    kernel!(y, tanh, scale, bias, x, μ, σ²; ndrange=size(x))
    KernelAbstractions.synchronize(backend)
    return y
end

N, B = 32, 32

x = randn(Float32, N, B);
scale = randn(Float32, N);
bias = randn(Float32, N);
μ = rand(Float32, N);
σ² = rand(Float32, N);
x_cl = CLArray(x);
scale_cl = CLArray(scale);
bias_cl = CLArray(bias);
μ_cl = CLArray(μ);
σ²_cl = CLArray(σ²);

@test batchnorm_ka(x, scale, bias, μ, σ²) ≈ batchnorm_looped(x, scale, bias, μ, σ²)

@benchmark batchnorm_looped($x, $scale, $bias, $μ, $σ²)
@benchmark batchnorm_ka($x, $scale, $bias, $μ, $σ²)
@benchmark batchnorm_ka($x_cl, $scale_cl, $bias_cl, $μ_cl, $σ²_cl)
julia> @benchmark batchnorm_looped($x, $scale, $bias, $μ, $σ²)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  15.391 μs …  50.510 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     15.511 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   15.606 μs ± 645.765 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

   ▂▆▇██▇▇▆▆▅▅▄▄▃▃▁▂▁▂▁▁▁▁▁▁▁▁▁▁ ▁                             ▃
  ▆██████████████████████████████████████▆▇▆▆▆▅▄▅▄▅▃▃▃▄▄▄▃▅▁▃▃ █
  15.4 μs       Histogram: log(frequency) by time      16.6 μs <

 Memory estimate: 4.08 KiB, allocs estimate: 3.

julia> @benchmark batchnorm_ka($x, $scale, $bias, $μ, $σ²)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  164.321 μs …  15.296 ms  ┊ GC (min … max): 0.00% … 98.61%
 Time  (median):     166.691 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   181.080 μs ± 218.323 μs  ┊ GC (mean ± σ):  3.55% ±  3.91%

    ▂█▁                                                          
  ▃▆███▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▄▄▄▄▃▂▂▃▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂ ▃
  164 μs           Histogram: frequency by time          200 μs <

 Memory estimate: 52.41 KiB, allocs estimate: 3084.

julia> @benchmark batchnorm_ka($x_cl, $scale_cl, $bias_cl, $μ_cl, $σ²_cl)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   78.521 μs …   3.324 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     402.148 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   400.169 μs ± 186.617 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

   ▁▇█▇▇▅▅▃▃▃▂▃▇▅▄▅▂▃▂▄▃▄▃▄▃▇▃▂ █▇▄▃▂▃▂▂▃█▇▂▃▇▄▃▆▆▃▃▅▇▅▆▅▂▂▃▆▅▃  
  ▃████████████████████████████████████████████████████████████ █
  78.5 μs          Histogram: frequency by time          707 μs <

 Memory estimate: 6.45 KiB, allocs estimate: 96.

@maleadt
Copy link
Member

maleadt commented Sep 27, 2024

That's not too bad, I guess? The minimal performance at least is somewhat decent, but I'd be curious where the huge spread comes from. We haven't done any optimization of OpenCL.jl, so there's probably lots of low-hanging fruit there. Additionally, the PoCL build is not optimized: anything but x86 will perform badly, and we're currently using the default pthread CPU driver while there are also OpenMP/oneTBB-based drivers that are known to perform better.

@avik-pal
Copy link
Contributor Author

Digging a bit further: for the KA CPU backend, if I pull the activation out of the kernel, the performance is much better.

@kernel inbounds=true function batchnorm_kernel_act!(y, @Const(scale), @Const(bias), @Const(x), @Const(μ), @Const(σ²))
    i, j = @index(Global, NTuple)
    res = scale[i] * (x[i, j] - μ[i]) / sqrt(σ²[i] + 1e-5) + bias[i]
    y[i, j] = res
end

function batchnorm_ka(x::AbstractMatrix, scale::AbstractVector, bias::AbstractVector, μ::AbstractVector, σ²::AbstractVector)
    y = similar(x)
    backend = KernelAbstractions.get_backend(x)
    kernel! = batchnorm_kernel_act!(backend)
    kernel!(y, scale, bias, x, μ, σ²; ndrange=size(x))
    @. y = tanh(y)
    KernelAbstractions.synchronize(backend)
    return y
end
julia> @benchmark batchnorm_looped($x, $scale, $bias, $μ, $σ²)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  15.380 μs …  53.000 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     15.510 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   15.586 μs ± 646.284 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

     ▂▁▃▄███▇▇▆▅▅▆▃▃▂▃▁▁ ▂▁▁ ▁   ▁   ▁                         ▂
  ▇████████████████████████████▇▇█▆███▇█▇██▇▇██▇▇██▇▇█▇▆▇▇▇▆▅▅ █
  15.4 μs       Histogram: log(frequency) by time      16.1 μs <

 Memory estimate: 4.08 KiB, allocs estimate: 3.

julia> @benchmark batchnorm_ka($x, $scale, $bias, $μ, $σ²)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  11.250 μs …  50.540 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     11.430 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   11.493 μs ± 555.961 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

          █  ▃ ▂▇  ▁                                            
  ▁▁▃▃▂▃▃▅██▆████▅▄█▄▅▃▃▄▂▂▃▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  11.2 μs         Histogram: frequency by time         12.1 μs <

 Memory estimate: 4.41 KiB, allocs estimate: 12.

This makes OpenCL slightly slower:

julia> @benchmark batchnorm_ka($x_cl, $scale_cl, $bias_cl, $μ_cl, $σ²_cl)
BenchmarkTools.Trial: 9522 samples with 1 evaluation.
 Range (min … max):   99.260 μs …   1.161 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     533.078 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   521.982 μs ± 241.771 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

   ▂▅▅▄▃▁ ▂▁▁▁▁▂▁▁ ▂  ▃ ▂▁▁▂▂▁ ▄▁  ▄▂ ▁▃▄▂▁ ▂▄▃  ▃▃▅ ▃▃▂█▄▂▄     
  ▄██████████████████▇█▆██████▇██▇███▆█████▇████████▇████████▇▆ ▆
  99.3 μs          Histogram: frequency by time          921 μs <

 Memory estimate: 8.61 KiB, allocs estimate: 132.

Even if I pass act = identity to the KA kernel, the time shoots up to 160 μs. I would have assumed identity gets compiled out?
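
As a sanity check outside KA, identity does fold away in a plain higher-order call, since Julia specializes on function arguments; so the overhead presumably comes from how the CPU backend handles the extra kernel argument:

using InteractiveUtils

f(act, x) = act(x)
@code_llvm debuginfo=:none f(identity, 1.0f0)  # identity folds away: the body just returns its argument
@code_llvm debuginfo=:none f(tanh, 1.0f0)      # compare: the tanh computation remains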
