Saving and Loading Models #671

Open
hxk1633 opened this issue May 3, 2023 · 7 comments

Comments

@hxk1633

hxk1633 commented May 3, 2023

Is there a way to save a fitted model to disk and then load it later to make predictions?

@tomicapretto
Collaborator

Currently there's no way to do so. You could write something to store the inference data and the metadata of the model (the formula, the model family, the priors, etc.) and then load it again. It's along the lines of what I would be interested in doing, but I haven't had time lately.

@5hv5hvnk

5hv5hvnk commented Jul 7, 2023

@hxk1633 You can try something like this from pymc-experimental

@humana

humana commented Apr 5, 2024

> Currently there's no way to do so. You could write something to store the inference data and the metadata of the model (the formula, the model family, the priors, etc.) and then load it again. It's along the lines of what I would be interested in doing, but I haven't had time lately.

Could you provide some details about how I could go about implementing something like this? Currently we have a model built with Bambi and we need to do predictions in production on new data that changes every few hours. I really need a way to serialise the Bambi model so it can be loaded again and serve predictions via an API.

@tomicapretto
Collaborator

@humana have a look at the following example

Script 1: this is where you first create and "train" your model:

import pickle

import arviz as az
import bambi as bmb

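# "my_data" is a placeholder: replace this with your own data frame,
# which must contain the columns used in the formula (y, x, and z)
df = bmb.load_data("my_data")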

# Store all the arguments you pass to `bmb.Model()` in a dict that is pickled
family = "gaussian"
formula = "y ~ x + z"
priors = {
    "Intercept": bmb.Prior("Normal", mu=0.5, sigma=1),
    "x": bmb.Prior("Normal", mu=0, sigma=1),
    "z": bmb.Prior("Normal", mu=0, sigma=2),
}

args_dict = {
    "formula": formula,
    "data": df,
    "family": family,
    "priors": priors
}

# Create and fit model
model = bmb.Model(**args_dict)
idata = model.fit(random_seed=1234)

# Store things on disk
# Model metadata (required to re-create the model)
with open("model_args_dict.pickle", "wb") as handle:
    pickle.dump(args_dict, handle)

# InferenceData object (contains draws from the posterior)
idata.to_netcdf("idata.nc")

Script 2: this is what you use to obtain predictions on new datasets without having to build or fit the underlying PyMC model again:

import pickle

import arviz as az
import bambi as bmb
import numpy as np
import pandas as pd

# This is a new data frame
df_new = pd.DataFrame({"x": np.random.normal(size=10), "z": np.random.normal(size=10)})

# Load original arguments
with open("model_args_dict.pickle", "rb") as handle:
    args_dict_loaded = pickle.load(handle)

# Re-create the Bambi model (but this doesn't recreate the PyMC model unless you .build() it)
model_loaded = bmb.Model(**args_dict_loaded)

# Load the posterior draws (and other data too)
idata_loaded = az.from_netcdf("idata.nc")

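# Use the model to predict on the new dataset.
# With inplace=False, predict() returns a new InferenceData object
# instead of modifying idata_loaded, so keep the return value.
# Note: newer Bambi releases rename kind="pps" to kind="response".
idata_preds = model_loaded.predict(idata_loaded, data=df_new, inplace=False, kind="pps")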
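For reuse in a service, the two scripts above could be wrapped into a pair of helpers. This is only a minimal sketch of the same pattern: the function names `save_fitted` and `load_fitted` and the idea of keeping both files in one directory are assumptions for illustration, not part of Bambi's API. It uses only the calls shown above (`pickle`, `idata.to_netcdf`, `az.from_netcdf`, `bmb.Model`).

```python
import pickle
from pathlib import Path

# Hypothetical helper names; the file names match the scripts above.
ARGS_FILE = "model_args_dict.pickle"
IDATA_FILE = "idata.nc"


def save_fitted(args_dict, idata, directory):
    """Write the `bmb.Model()` arguments and the posterior draws into `directory`."""
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    # Model metadata (required to re-create the Bambi model)
    with open(directory / ARGS_FILE, "wb") as handle:
        pickle.dump(args_dict, handle)
    # InferenceData object (contains draws from the posterior)
    idata.to_netcdf(str(directory / IDATA_FILE))


def load_fitted(directory):
    """Re-create the Bambi model and load the draws written by `save_fitted`."""
    # Imported lazily: only loading calls Bambi and ArviZ directly
    import arviz as az
    import bambi as bmb

    directory = Path(directory)
    with open(directory / ARGS_FILE, "rb") as handle:
        args_dict = pickle.load(handle)
    # Re-creates the Bambi model without building the PyMC model
    model = bmb.Model(**args_dict)
    idata = az.from_netcdf(str(directory / IDATA_FILE))
    return model, idata
```

In the training script you would call `save_fitted(args_dict, idata, "artifacts")`, and in the serving process `model_loaded, idata_loaded = load_fitted("artifacts")` once at startup, then call `predict` on each new data frame as in Script 2. Since `pickle.load` can execute arbitrary code, only load files you created yourself.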

@tomicapretto
Collaborator

@GStechschulte I think making this pattern more visible in our docs could help more people. What do you think?

Also, it's actually quite fast (as it doesn't have to compile many things on the PyMC side)

@humana

humana commented Apr 9, 2024

Thank you, this is pretty much what I came up with myself after reading through what the predict function would need, but I was worried I might have missed something because it looked too simple. Very helpful.

@tomicapretto
Collaborator

> Thank you, this is pretty much what I came up with myself after reading through what the predict function would need, but I was worried I might have missed something because it looked too simple. Very helpful.

Great! This is possible because Bambi "knows" how to compute a lot of things without relying on PyMC/PyTensor graph structure. If we have a Bambi model and the inference data object, we can generate predictions without having to build the PyMC model at all. However, we do use the draw() function from PyMC to get the draws from a PyMC distribution. We could avoid this step in many cases, but we would need to maintain a larger and more confusing codebase.

I'm happy you were able to work it out. Just let me know if you have any other questions.
