week2 pickle model for Lasso #201

Open
lucapug opened this issue May 26, 2023 · 0 comments

This refers to the week 2 modified version of duration-prediction.ipynb (link).

I think the pickle file passed to log_artifact() in the last line of the following code block is the wrong one. lin_reg.bin is saved outside the MLflow run, so it contains the model obtained by fitting linear regression without regularization. That model is different from the one fitted inside the experiment run, which is a Lasso model.

lr = LinearRegression()
lr.fit(X_train, y_train)

y_pred = lr.predict(X_val)

mean_squared_error(y_val, y_pred, squared=False)
7.758715210382775

with open('models/lin_reg.bin', 'wb') as f_out:
    pickle.dump((dv, lr), f_out)

with mlflow.start_run():

    mlflow.set_tag("developer", "cristian")

    mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
    mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)
    lr = Lasso(alpha)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    mlflow.log_artifact(local_path="models/lin_reg.bin", artifact_path="models_pickle")
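One possible fix is to pickle the Lasso model inside the run and log that file instead of the earlier lin_reg.bin. Below is a minimal sketch of that idea, not the course's official solution. The helper name `dump_and_log` and the file name `models/lasso.bin` are hypothetical. The logging function is passed in as an argument, so the helper itself does not import MLflow.

```python
import pickle
from pathlib import Path


def dump_and_log(dv, model, local_path, log_artifact, artifact_path="models_pickle"):
    """Pickle (dv, model) to local_path, then pass that same file to log_artifact.

    `log_artifact` is any callable accepting `local_path` and `artifact_path`
    keyword arguments, for example `mlflow.log_artifact`.
    """
    local_path = Path(local_path)
    # Create the target directory if it does not exist yet.
    local_path.parent.mkdir(parents=True, exist_ok=True)
    with open(local_path, "wb") as f_out:
        pickle.dump((dv, model), f_out)
    # Log the file that was just written, so the artifact matches the fitted model.
    log_artifact(local_path=str(local_path), artifact_path=artifact_path)
    return local_path


# Hypothetical usage inside the run from the notebook, after the Lasso fit:
#
#     lr = Lasso(alpha)
#     lr.fit(X_train, y_train)
#     ...
#     dump_and_log(dv, lr, "models/lasso.bin", mlflow.log_artifact)
```

Writing the pickle and logging it in the same step keeps the logged artifact tied to the model whose alpha and rmse are recorded in that run.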