remove numba dependency #753

Open
jmoralez opened this issue Jan 9, 2024 · 2 comments

@jmoralez
Member

jmoralez commented Jan 9, 2024

Description

We've relied heavily on numba to speed up our models. However, we don't need its JIT compilation, since all the code that uses it is defined inside the library.

Replacing numba jitted code with a compiled alternative (C++ or Rust for example) would provide the following benefits:

And the following drawbacks:

  • More time to develop new models.
  • It would make it more difficult for users to install from github since they'd need to have a compiler installed. Although we could publish nightly wheels for this.

Use case

This will benefit the development process, since even with the cache enabled it can take a couple of seconds to run jitted functions for the first time.
Deployments would also be smoother, because either:

  1. People that were copying over the cache won't have to do it anymore.
  2. People that weren't copying the cache won't experience cold starts anymore.
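
To make the cold-start cost concrete, here is a minimal sketch (not part of the library) that times consecutive calls of a function. The helper name `time_calls` and the toy function `total` are made up for illustration. With a numba-jitted function the first call includes compilation, so it is typically much slower than the second.

```python
import time


def time_calls(fn, *args, repeats=2):
    """Return the wall-clock seconds taken by each of `repeats` consecutive calls."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        durations.append(time.perf_counter() - start)
    return durations


if __name__ == "__main__":
    try:
        from numba import njit
    except ImportError:
        print("numba is not installed, skipping the jitted example")
    else:
        @njit
        def total(n):
            # Toy loop, only here to give numba something to compile.
            acc = 0.0
            for i in range(n):
                acc += i
            return acc

        first, second = time_calls(total, 1_000)
        print(f"first call: {first * 1000:.1f} ms, second call: {second * 1000:.1f} ms")
```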
@AzulGarza
Member

i love this. i think we should move forward with the enhancement.

regarding the second drawback, i agree with the nightly wheels option. also adopting a compiled alternative will force us to make releases more often. 🙌

regarding the first drawback, we could still have some models relying on numba as an optional dependency. for example, if we release a new model written in numba, it will be only available installing statsforecast[numba]. this approach could help us iterate new models faster and once they are stable, we could migrate them to their compiled version. :)
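
a rough sketch of how the optional dependency could be guarded, assuming a `numba` extra existed (it's only a proposal here). `require_numba` and `ExperimentalModel` are hypothetical names, not part of statsforecast:

```python
try:
    import numba  # optional dependency
except ImportError:
    numba = None

HAS_NUMBA = numba is not None


def require_numba(feature):
    """Raise a helpful error when a numba-only feature is used without numba."""
    if not HAS_NUMBA:
        raise ImportError(
            f"{feature} requires numba. "
            "Install it with: pip install 'statsforecast[numba]'"
        )


class ExperimentalModel:
    """Hypothetical model that is still implemented with numba."""

    def __init__(self):
        # Fail early with installation instructions instead of at fit time.
        require_numba(type(self).__name__)
```

this way importing the library never fails without numba, and only the models that still need it complain.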

@jmoralez
Member Author

I used the following to profile the current compilation times:

import datetime
import operator
import os
os.environ.pop('NIXTLA_NUMBA_CACHE', None)
from collections import defaultdict

import numpy as np
import pandas as pd
from numba.core.event import install_recorder

from statsforecast.core import StatsForecast
from statsforecast.models import *
from statsforecast.utils import generate_series

data = generate_series(2)
models = [
    AutoARIMA(season_length=7),
    AutoCES(season_length=7),
    AutoETS(season_length=7),
    AutoTheta(season_length=7),
    SimpleExponentialSmoothing(alpha=0.1),
    GARCH(),
    TBATS(seasonal_periods=7),
]
sf = StatsForecast(models=models, freq='D')
with install_recorder("numba:compile") as rec:
    forecast = sf.forecast(df=data, h=7)

events = defaultdict(dict)
for ts, event in rec.buffer:
    if event.is_start:
        stage = 'start'
    else:
        stage = 'end'
    events[event.data['dispatcher']][stage] = ts

comp_times_ms = []
for fn, times in events.items():
    module = fn.py_func.__module__
    # Only keep functions defined inside statsforecast.
    if not module.startswith('statsforecast'):
        continue
    name = f'{module}.{fn.__name__}'
    start = datetime.datetime.fromtimestamp(times['start'])
    end = datetime.datetime.fromtimestamp(times['end'])
    # total_seconds() keeps the whole duration; .microseconds alone drops full seconds.
    time_in_ms = round((end - start).total_seconds() * 1000)
    comp_times_ms.append((name, time_in_ms))

top_fns = sorted(comp_times_ms, key=operator.itemgetter(1), reverse=True)

times_by_module = defaultdict(int)
for fn, time in top_fns:
    times_by_module[fn.split('.')[1]] += time
top_modules = sorted(times_by_module.items(), key=operator.itemgetter(1), reverse=True)

And got the following results:

Times in milliseconds by function:

[('statsforecast.theta.initstate', 936),
 ('statsforecast.ces.switch_ces', 907),
 ('statsforecast.theta.initparamtheta', 881),
 ('statsforecast.arima._make_arima', 849),
 ('statsforecast.ces.cesfcst', 828),
 ('statsforecast.ces.ces_target_fn', 823),
 ('statsforecast.tbats.makeTBATSFMatrix', 793),
 ('statsforecast.garch.garch_sigma2', 739),
 ('statsforecast.ets.ets_target_fn', 734),
 ('statsforecast.theta.pegelsresid_theta', 696),
 ('statsforecast.ets.pegelsresid_C', 690),
 ('statsforecast.tbats.calcTBATSFaster', 683),
 ('statsforecast.ets.nelder_mead_ets', 664),
 ('statsforecast.theta.theta_target_fn', 605),
 ('statsforecast.arima.arima_gradtrans', 564),
 ('statsforecast.ces.initparamces', 557),
 ('statsforecast.arima.arima_css', 546),
 ('statsforecast.theta.thetafcst', 513),
 ('statsforecast.theta.nelder_mead_theta', 503),
 ('statsforecast.ces.nelder_mead_ces', 490),
 ('statsforecast.ces.cescalc', 486),
 ('statsforecast.theta.thetacalc', 483),
 ('statsforecast.arima.getQ0', 477),
 ('statsforecast.ets.etscalc', 476),
 ('statsforecast.ets.etsforecast', 453),
 ('statsforecast.arima.diff1d', 435),
 ('statsforecast.theta.switch_theta', 432),
 ('statsforecast.ces.cesupdate', 379),
 ('statsforecast.arima.partrans', 370),
 ('statsforecast.arima.ARIMA_invtrans', 351),
 ('statsforecast.ets.update', 327),
 ('statsforecast.ets.restrict_to_bounds', 325),
 ('statsforecast.garch.garch_loglik', 288),
 ('statsforecast.models._ses_fcst_mse', 286),
 ('statsforecast.ets.switch', 274),
 ('statsforecast.arima.inclu2', 256),
 ('statsforecast.arima.arima_undopars', 256),
 ('statsforecast.theta.thetaupdate', 255),
 ('statsforecast.ets.forecast', 252),
 ('statsforecast.arima.tsconv', 240),
 ('statsforecast.ces.cesforecast', 230),
 ('statsforecast.arima.arima_like', 209),
 ('statsforecast.theta.thetaforecast', 193),
 ('statsforecast.arima.invpartrans', 189),
 ('statsforecast.theta.is_constant', 176),
 ('statsforecast.arima.arima_transpar', 174),
 ('statsforecast.ets.is_constant', 161),
 ('statsforecast.ets.initparam', 151),
 ('statsforecast.ces.pegelsresid_ces', 128),
 ('statsforecast.garch.garch_cons', 72),
 ('statsforecast.arima.kalman_forecast', 22)]

Times in milliseconds by module:

[('theta', 5673),
 ('arima', 4938),
 ('ces', 4828),
 ('ets', 4507),
 ('tbats', 1476),
 ('garch', 1099),
 ('models', 286)]

So I believe we can migrate them in that order (I already migrated ETS in #757 because I profiled this wrong xD) but we can continue with Theta next.
