Sampling performance b-splines model chapter 04 #233

Open
hannes-tw opened this issue Dec 11, 2023 · 1 comment
@hannes-tw
Hello.

I am running into performance issues when specifying and sampling from a linear model with B-splines in PyMC (chapter 4, R code reference 4.76). Sampling is fast as long as I keep the number of knots and/or the spline degree low enough (below 10 knots and degree 2). As soon as I raise either parameter above those values, sampling no longer finishes within ~30 seconds and instead takes 20+ minutes.

Has anybody else run into similar issues? I am currently puzzled by the rapid performance drop...
Any help or pointers would be appreciated.

Code

import numpy as np
import pymc as pm
from patsy import dmatrix

# data: cherry_blossoms (loading not shown; it has "year" and "doy" columns)

n_knots = 10
b_spline_degree = 2
years = cherry_blossoms["year"].values
day_of_blossom = cherry_blossoms["doy"].values

knots = np.quantile(years, np.linspace(0, 1, n_knots))
b_spline_basis_dm = dmatrix(
    "bs(year, knots=knots, degree=b_spline_degree, include_intercept=True) - 1",
    {
        "year": years,
        "knots": knots[1:-1],
        "b_spline_degree": b_spline_degree,
    },
)

b_spline_basis = np.asarray(b_spline_basis_dm)

with pm.Model() as b_spline_model:
    alpha = pm.Normal("alpha", mu=100, sigma=10)
    beta = pm.Normal("beta", mu=0, sigma=10, shape=b_spline_basis.shape[1])
    sigma = pm.Exponential("sigma", 1)
    mu = pm.Deterministic("mu", alpha + pm.math.dot(b_spline_basis, beta.T))
    likelihood = pm.Normal(
        "likelihood", mu=mu, sigma=sigma, observed=day_of_blossom
    )

    trace = pm.sample(draws=1000, tune=1000, chains=4)
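
One thing that may be worth checking, offered as an assumption rather than a confirmed diagnosis: with `include_intercept=True`, the B-spline basis columns sum to 1 at every year, so the separate `alpha` term is collinear with the basis, and that kind of redundancy can make a posterior harder to sample. The sketch below is a plain-Python Cox-de Boor evaluation (not patsy's implementation; `bspline_basis` and the toy knots are made up for illustration). It shows the row sum and the number of basis columns, which is inner knots + degree + 1.

```python
def bspline_basis(x, inner_knots, degree, lower, upper):
    """Evaluate every B-spline basis function at x via Cox-de Boor recursion.

    Hypothetical helper for illustration only; it assumes lower <= x <= upper
    and strictly increasing inner knots inside (lower, upper).
    """
    # Clamped knot vector: each boundary knot is repeated degree + 1 times.
    t = [lower] * (degree + 1) + list(inner_knots) + [upper] * (degree + 1)
    n_basis = len(t) - degree - 1  # equals len(inner_knots) + degree + 1

    # Degree 0: indicator of the half-open knot interval that contains x.
    b = [0.0] * (len(t) - 1)
    for i in range(len(t) - 1):
        if t[i] <= x < t[i + 1]:
            b[i] = 1.0
    if x == upper:
        # Assign the right boundary to the last non-empty interval.
        b[n_basis - 1] = 1.0

    # Raise the degree one step at a time; zero-width spans contribute 0.
    for k in range(1, degree + 1):
        nb = [0.0] * (len(t) - 1 - k)
        for i in range(len(nb)):
            left_span = t[i + k] - t[i]
            right_span = t[i + k + 1] - t[i + 1]
            left = (x - t[i]) / left_span * b[i] if left_span > 0 else 0.0
            right = (t[i + k + 1] - x) / right_span * b[i + 1] if right_span > 0 else 0.0
            nb[i] = left + right
        b = nb
    return b


# Toy setup (assumed values, not the cherry blossom data).
inner = [2.0, 4.0, 6.0, 8.0]
degree = 2
row = bspline_basis(3.0, inner, degree, lower=0.0, upper=10.0)

print(len(row))  # number of basis columns: len(inner) + degree + 1
print(sum(row))  # row sum of the basis at x = 3.0
```

If the row sums come out as 1 for the real design matrix too (for example via `b_spline_basis.sum(axis=1)`), it may be worth comparing sampling time with either `alpha` or `include_intercept=True` removed.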

Versions:
name : pymc
version : 5.9.1

name : numpy
version : 1.25.2

Best,
htw

ajaya0 commented Jan 16, 2024

Can you explain more?
