How to improve CPU performance? #357

Open
Moelf opened this issue Feb 13, 2023 · 0 comments

Moelf commented Feb 13, 2023

Consider this example from a JAX discussion:

Source code
function run_julia(height, width)
    y = range(-1.0f0, 0.0f0; length = height) # need Float32 because Jax defaults to it
    x = range(-1.5f0, 0.0f0; length = width)
    c = x' .+ y*im
    fractal = fill(Int32(20), height, width)

    # `eachindex(c, fractal)` checks that the indices of `c` and `fractal` are compatible
    @inbounds for idx in eachindex(c, fractal)
        _c = c[idx]
        z = _c
        m = true
        Base.Cartesian.@nexprs 20 i -> begin
            z = z^2 + _c
            az4 = abs2(z) > 4f0
            fractal[idx] = ifelse(m&az4, Int32(i), fractal[idx]) # 32-bit Int, same reason as above
            m &= (!az4)
        end
    end
    return fractal
end

using KernelAbstractions
    
@kernel function julia_kernel!(c, fractal)    
    I = @index(Global)    
    _c = c[I]    
    z = _c       
    @inbounds for i = 1:20    
        z = z^2 + _c    
        if abs2(z) > 4f0    
            fractal[I] = Int32(i)    
            break    
        end    
    end    
end

function run_julia_cpu_jaxstype(height, width)
    y = range(-1.0f0, 0.0f0; length = height)  
    x = range(-1.5f0, 0.0f0; length = width)  
    c = x' .+ y*im
    fractal = fill(Int32(20), height, width)     

    kernel! = julia_kernel!(CPU(), length(c)÷Threads.nthreads()) # workgroup size = length(c) ÷ nthreads; only 1 thread is used here
    event = kernel!(c, fractal; ndrange=length(c))
    wait(event) # not copying back, need to block here
    return fractal
end

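With a single thread, the KernelAbstractions version is about 3.6x slower than the plain loop (median 177.2 ms vs. 49.8 ms):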

julia> Threads.nthreads()
1

julia> @benchmark run_julia(2000,3000)
BenchmarkTools.Trial: 100 samples with 1 evaluation.
 Range (min … max):  49.380 ms …  52.606 ms  ┊ GC (min … max): 0.37% … 1.04%
 Time  (median):     49.789 ms               ┊ GC (median):    0.98%
 Time  (mean ± σ):   49.982 ms ± 583.818 μs  ┊ GC (mean ± σ):  0.98% ± 0.07%

      ▅█▇▄   ▃
  ▃▁▅▇█████▃▇██▅▇▅▁▆▆▃▆▃▁▆▃▅▁▃▁▁▁▁▁▃▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▆ ▃
  49.4 ms         Histogram: frequency by time           52 ms <

 Memory estimate: 68.66 MiB, allocs estimate: 4.

julia> @benchmark run_julia_cpu_jaxstype(2000,3000)
BenchmarkTools.Trial: 29 samples with 1 evaluation.
 Range (min … max):  176.674 ms … 179.932 ms  ┊ GC (min … max): 0.27% … 0.08%
 Time  (median):     177.228 ms               ┊ GC (median):    0.26%
 Time  (mean ± σ):   177.528 ms ± 780.443 μs  ┊ GC (mean ± σ):  0.26% ± 0.03%

  ▃▃ ▃▃  ▃                  █
  ██▇██▇▁█▇▇▇▁▇▁▁▁▁▇▁▁▇▇▁▁▇▇█▇▁▇▇▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
  177 ms           Histogram: frequency by time          180 ms <

 Memory estimate: 68.67 MiB, allocs estimate: 28.
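
One way to narrow down where the time goes might be a plain-loop baseline that uses the same early-exit body as `julia_kernel!` but no KernelAbstractions machinery. The sketch below is only a diagnostic idea; `run_julia_break` is a hypothetical helper that is not part of the original report, and it has not been benchmarked here.

```julia
# Hypothetical helper (not in the original report): same setup as `run_julia`,
# but the inner loop mirrors `julia_kernel!` (branch + `break`) instead of the
# unrolled, branch-free `@nexprs` version.
function run_julia_break(height, width)
    y = range(-1.0f0, 0.0f0; length = height)
    x = range(-1.5f0, 0.0f0; length = width)
    c = x' .+ y*im
    fractal = fill(Int32(20), height, width)

    @inbounds for idx in eachindex(c, fractal)
        _c = c[idx]
        z = _c
        for i = 1:20
            z = z^2 + _c
            if abs2(z) > 4f0
                fractal[idx] = Int32(i) # record the first escaping iteration
                break
            end
        end
    end
    return fractal
end
```

Running `@benchmark run_julia_break(2000, 3000)` next to the two benchmarks above would show which side it falls on. If it is close to `run_julia`, the gap is more likely in how the CPU backend launches and iterates the kernel; if it is close to `run_julia_cpu_jaxstype`, the branchy loop body itself is the more likely cost.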