
jitclass appraisal and performance revisit request #9515

Open
dgrigonis opened this issue Mar 27, 2024 · 4 comments
Labels: jitclass, performance

Comments


dgrigonis commented Mar 27, 2024

Hello,

First of all, I would like to take a moment to appreciate jitclass. It is the only way to emulate coroutines that take in values.

Given the various complications of passing a function as an argument to a numba function, I found this approach very attractive for a certain class of cases, i.e. using a jitclass to emulate a loop that stops from time to time and asks for new values.

E.g.

jc = JitClass()
request = jc.next()                    # ask for the first request
while True:
    response = parse_request(request)  # work done outside numba
    request = jc.next(response)        # feed the result back in

Although there is an overhead to the function calls, this does improve performance in certain cases while factoring out operations that are best left outside numba.
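For concreteness, here is a minimal, self-contained sketch of that pattern. The StepLoop name, its fields, and the request/response semantics are my own illustration of the idea, not anything from the numba API:

import numba as nb

# Hypothetical sketch: the jitclass keeps its loop state in typed fields
# and advances one step per call from the interpreter.
@nb.experimental.jitclass([
    ('i', nb.int64),        # current position in the loop
    ('total', nb.float64),  # state accumulated across calls
])
class StepLoop:
    def __init__(self):
        self.i = 0
        self.total = 0.0

    def next(self, response):
        # Fold the externally computed response into the state,
        # then hand back the next "request" (here simply the index).
        self.total += response
        self.i += 1
        return self.i


loop = StepLoop()
request = loop.next(0.0)             # seed with a dummy response
while request < 5:
    response = float(request) * 2.0  # work done in plain Python
    request = loop.next(response)
print(loop.total)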


Having said that, my request:
I would like to ask for some optimizations of the jitclass call overhead.


And observations:

In short: properties are faster than pure Python, but static methods and instance methods are unbelievably slow. func below is a plain njit function analogous to the stat static method of the jitclass object; the performance of these two should be similar, but it is very different. However, the main bottleneck for my case is meth, the instance method. More complex methods are fine, since the call overhead is small compared to the runtime of the body (such as incr(ement) below), but to gain the most from applications like the one described above, the call overhead needs to be optimized.

lambdas = {
    'init': [lambda: BagPy(n), lambda: BagNb(n)],
    'prop': [lambda: b_py.size, lambda: b_nb.size],
    'meth': [lambda: b_py.add_im(1, 2), lambda: b_nb.add_im(1, 2)],
    'stat': [lambda: b_py.add(1, 2), lambda: b_nb.add(1, 2)],
    'func': [lambda: add_py(1, 2), lambda: add_nb(1, 2)],
    'incr': [lambda: b_py.increment(1), lambda: b_nb.increment(1)],
}

┌5 repeats, 1,000 times──────┐
│ units: ns │ py    │ nb     │
├───────────┼───────┼────────┤
│ init      │   533 │   6667 │
│ prop      │    92 │     69 │
│ meth      │    64 │    699 │
│ stat      │    77 │    718 │
│ func      │    56 │    141 │
│ incr      │ 40867 │   1220 │
└───────────┴───────┴────────┘

See code below:

import numba as nb
import numpy as np


class BagPy:
    def __init__(self, value):
        self.value = value
        self.array = np.zeros(value, dtype=np.float32)
    @property
    def size(self):
        return self.array.size
    def add_im(self, x, y):
        return x + y
    @staticmethod
    def add(x, y):
        return x + y
    def increment(self, val):
        for i in range(self.size):
            self.array[i] += val
        return self.array

def add_py(a, b):
    return a + b


@nb.experimental.jitclass([
    ('value', nb.int32),               # a simple scalar field
    ('array', nb.float32[:]),          # an array field
])
class BagNb:
    def __init__(self, value):
        self.value = value
        self.array = np.zeros(value, dtype=np.float32)
    @property
    def size(self):
        return self.array.size
    def add_im(self, x, y):
        return x + y
    @staticmethod
    def add(x, y):
        return x + y
    def increment(self, val):
        for i in range(self.size):
            self.array[i] += val
        return self.array

@nb.njit
def add_nb(a, b):
    return a + b


n = 21
b_py = BagPy(n)
b_nb = BagNb(n)
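For a quick sanity check of the stat vs func gap on this setup, something as simple as the following works (this is not the code that produced the table above):

from timeit import timeit

# Warm up the compiled paths first so compilation is excluded.
b_nb.add(1, 2)
add_nb(1, 2)

n_calls = 100_000
stat_ns = timeit(lambda: b_nb.add(1, 2), number=n_calls) / n_calls * 1e9
func_ns = timeit(lambda: add_nb(1, 2), number=n_calls) / n_calls * 1e9
print(f'stat: {stat_ns:.0f} ns, func: {func_ns:.0f} ns')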
guilhermeleobas added the jitclass and performance labels on Mar 28, 2024
guilhermeleobas (Collaborator) commented:

@dgrigonis, could you share the code you used to build the table presented above?

dgrigonis (Author) commented:

I can't share the exact same code, but for quick reproducibility, this is almost the same:

from statistics import mean, stdev
from timeit import timeit
from terminaltables import SingleTable


def timeit_lambda_dict(lambdas, repeats=5, n=100, u='µs', std=False, colnames=()):
    mult = {'s': 1, 'ms': 1e3, 'µs': 1e6, 'ns': 1e9}
    m = mult[u]
    rows = [[f'units: {u}'] + list(colnames)]
    for k, values in lambdas.items():
        row = [k]
        for v in values:
            reps = [timeit(v, number=n) / n * m for _ in range(repeats)]
            s = f'{int(mean(reps)):>4}'
            if std:
                s += f' ± {int(stdev(reps)):>3}'
            row.append(s)
        rows.append(row)
    table = SingleTable(rows, title=f'{repeats} repeats, {n:,} times')
    print(table.table)

timeit_lambda_dict(lambdas, n=1000, colnames=['py', 'nb'], u='ns')


guilhermeleobas commented Mar 28, 2024

This is just a guess: the first time you call a jitted function, Numba has to compile the Python code down to LLVM, which is quite expensive, especially in a benchmark where the runtime itself is tiny (nanoseconds). The second table shows the result of repeating the execution 100 times instead of just 5. Also, I think there is an overhead in calling a compiled function from the interpreter that is expensive compared to the cost of the operations in your example.

$ python repro.py
┌5 repeats, 1,000 times─────┐
│ units: ns │ py    │ nb    │
├───────────┼───────┼───────┤
│ init      │  517  │ 3636  │
│ prop      │   94  │   82  │
│ meth      │   87  │ 8981  │
│ stat      │   90  │ 8387  │
│ func      │   73  │ 3388  │
│ incr      │ 26696 │ 13015 │
└───────────┴───────┴───────┘
┌100 repeats, 1,000 times──┐
│ units: ns │ py    │ nb   │
├───────────┼───────┼──────┤
│ init      │  495  │ 3535 │
│ prop      │   91  │   78 │
│ meth      │   86  │  473 │
│ stat      │   88  │  470 │
│ func      │   72  │  114 │
│ incr      │ 26083 │  730 │
└───────────┴───────┴──────┘
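As a minimal illustration of the first point (a sketch, not part of repro.py), the first, compiling call can be timed separately from a subsequent call:

import time
import numba as nb

@nb.njit
def add_nb(a, b):
    return a + b

t0 = time.perf_counter()
add_nb(1, 2)   # first call: includes type inference and compilation
t1 = time.perf_counter()
add_nb(1, 2)   # second call: only dispatch + execution
t2 = time.perf_counter()
print(f'first call (with compilation): {t1 - t0:.3f} s')
print(f'second call: {(t2 - t1) * 1e9:.0f} ns')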


dgrigonis commented Mar 28, 2024

Yeah, I think I did do a pre-run in my benchmarks.

There is one very obvious, straightforward problem: stat vs func. One is a static method, the other is a plain function; in theory they should perform identically.

I am not sure about the others, but the simple instance method seems to behave the same way. I wonder whether both of them can be brought down to the call time of a plain function.

Also, to add:
Compilation is very slow for jitclass; the one I have built takes 15 seconds to compile.

To see the results without compilation overhead, run the table twice in the same script: first run everything once as a warm-up, then do the proper timing run afterwards. That said, your second table with 100 × 1,000 iterations already seems to show a near-true picture.
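For example, a warm-up pass like this one (a sketch reusing the lambdas dict from above) keeps compilation out of the timings:

# Warm-up pass: call every benchmarked callable once so that any
# compilation happens before the timed runs.
for pair in lambdas.values():
    for fn in pair:
        fn()

timeit_lambda_dict(lambdas, n=1000, colnames=['py', 'nb'], u='ns')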
