Enhance plot #1094

Open
lcrmorin opened this issue Apr 26, 2024 · 0 comments
@lcrmorin

Describe the workflow you want to enable

Improve plot_splits for time series splits.

Currently the plot has some limitations. Here is an example with code:

import pandas as pd, numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_regression
from sklearn.dummy import DummyRegressor
from sklearn.metrics import mean_squared_error, make_scorer
from sklearn.model_selection import cross_val_score

from mlxtend.evaluate.time_series import GroupTimeSeriesSplit, plot_splits

# Build one small regression dataset per year; later years get more samples.
X_parts, y_parts = [], []

start_year = 2010
end_year = 2020

for year in np.arange(start_year, end_year + 1):
    X_year, y_year = make_regression(n_samples=5 + (year - start_year), n_features=2, bias=0, noise=1, random_state=year)
    X_year = pd.DataFrame(X_year).rename(columns={0: 'X1', 1: 'X2'})
    X_year['year'] = year
    y_year = pd.Series(y_year)
    X_parts.append(X_year)
    y_parts.append(y_year)

X, y = pd.concat(X_parts), pd.concat(y_parts)

# Modelling: one test year per split, with an expanding training window.
model = DummyRegressor(strategy="mean")
metric = mean_squared_error
cv_args = {"test_size": 1, "n_splits": len(np.unique(X['year'])) - 1, "window_type": "expanding"}
cv = GroupTimeSeriesSplit(**cv_args)

# MSE is an error (lower is better), so the scorer is built with greater_is_better=False.
scores = cross_val_score(model, X, y, cv=cv, groups=X['year'], scoring=make_scorer(metric, greater_is_better=False))

plot_splits(X, y, X['year'], **cv_args)
plt.show()

This produces the following plot:

[Image: plot_splits output for the example above]

As you can see:

  • there seems to be an off-by-one error on the index, so the red bars appear to overlap slightly. The values themselves do not overlap, but the plotted bars do.
  • the index labels get crowded very quickly
  • the groups are not shown
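To make the overlap point concrete: if each split's train and test indices are drawn as bars covering half-open ranges [start, start + width), adjacent bars share an edge but never overlap. The sketch below uses a hypothetical helper, `index_runs` (not part of mlxtend), that turns an array of sample indices into such (start, width) runs, which could then be passed to Matplotlib's `barh(y, width, left=start)`. It only illustrates the intended geometry; it does not identify where the off-by-one inside `plot_splits` actually comes from.

```python
import numpy as np


def index_runs(indices):
    """Split sample indices into contiguous runs (hypothetical helper).

    Returns a list of (start, width) pairs; each run covers the half-open
    range [start, start + width), so a train run ending at k and a test run
    starting at k touch without overlapping.
    """
    idx = np.sort(np.asarray(indices))
    if idx.size == 0:
        return []
    # A new run starts wherever two consecutive indices are not adjacent.
    breaks = np.flatnonzero(np.diff(idx) != 1) + 1
    return [(int(run[0]), int(run[-1] - run[0] + 1)) for run in np.split(idx, breaks)]


# Example with made-up indices: train on samples 0-4, test on samples 5-6.
train_runs = index_runs([0, 1, 2, 3, 4])
test_runs = index_runs([5, 6])
print(train_runs, test_runs)  # [(0, 5)] [(5, 2)]
```

The train bar ends at 0 + 5 = 5, exactly where the test bar starts, so nothing is drawn twice.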

Describe your proposed solution

It might be a good idea to:

  • correct the off-by-one error
  • remove the index labels
  • display the groups
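For the last two points, one possible approach: replace the per-sample ticks with one tick at the centre of each group and draw faint lines at the group boundaries. This is a minimal sketch, assuming the x axis is the sample position and that each group occupies a consecutive block of samples (as in the example above, where the data are built year by year). `group_bounds` and `label_axis_with_groups` are hypothetical helpers, not mlxtend functions.

```python
import numpy as np


def group_bounds(groups):
    """Return (starts, stops, labels) for consecutive blocks of equal group values.

    Block k covers the half-open sample range [starts[k], stops[k]).
    """
    groups = np.asarray(groups)
    # A new block starts wherever the group value changes.
    starts = np.r_[0, np.flatnonzero(groups[1:] != groups[:-1]) + 1]
    stops = np.r_[starts[1:], len(groups)]
    labels = [str(g) for g in groups[starts]]
    return starts, stops, labels


def label_axis_with_groups(ax, groups):
    """Replace per-sample x ticks on a Matplotlib Axes with one label per group."""
    starts, stops, labels = group_bounds(groups)
    ax.set_xticks((starts + stops) / 2)
    ax.set_xticklabels(labels, rotation=45)
    # Faint vertical lines at the inner group boundaries.
    for boundary in starts[1:]:
        ax.axvline(boundary, color="grey", linewidth=0.5)


# Usage with the issue's data (requires Matplotlib):
#   fig, ax = plt.subplots()
#   ... draw the train/test bars on ax ...
#   label_axis_with_groups(ax, X["year"])
```

With 11 years the axis then carries 11 labels instead of one tick per sample, regardless of how many samples each year holds.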

One option would be to plot only the groups, each with a constant size:

[Image: proposed plot with constant-size groups]
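The constant-size option could be sketched as a grid with one row per split and one equal-width cell per group, coloured by whether the group is unused, in the training set, or in the test set, regardless of how many samples it contains. `group_role_matrix` and `plot_group_roles` are hypothetical helpers; `splits` is assumed to be an iterable of (train_idx, test_idx) pairs such as those yielded by `cv.split(X, y, groups)`, and the colours are arbitrary choices.

```python
import numpy as np

UNUSED, TRAIN, TEST = 0, 1, 2


def group_role_matrix(groups, splits):
    """Return (matrix, group_labels) with one row per split and one column per group.

    matrix[i, j] is TRAIN, TEST or UNUSED for group j in split i.
    """
    groups = np.asarray(groups)
    group_labels = list(dict.fromkeys(groups.tolist()))  # unique, first-seen order
    column = {g: j for j, g in enumerate(group_labels)}
    rows = []
    for train_idx, test_idx in splits:
        row = np.full(len(group_labels), UNUSED)
        for g in set(groups[train_idx].tolist()):
            row[column[g]] = TRAIN
        for g in set(groups[test_idx].tolist()):
            row[column[g]] = TEST
        rows.append(row)
    matrix = np.array(rows, dtype=int).reshape(len(rows), len(group_labels))
    return matrix, group_labels


def plot_group_roles(ax, matrix, group_labels):
    """Draw the role matrix on a Matplotlib Axes, one equal-width cell per group."""
    # Imported here so group_role_matrix can be used without Matplotlib installed.
    from matplotlib.colors import ListedColormap

    cmap = ListedColormap(["white", "tab:blue", "tab:red"])  # unused, train, test
    ax.imshow(matrix, cmap=cmap, vmin=UNUSED, vmax=TEST, aspect="auto")
    ax.set_xticks(np.arange(len(group_labels)))
    ax.set_xticklabels([str(g) for g in group_labels], rotation=45)
    ax.set_yticks(np.arange(matrix.shape[0]))
    ax.set_xlabel("group")
    ax.set_ylabel("split")


# Usage with the issue's example (requires Matplotlib and mlxtend):
#   fig, ax = plt.subplots()
#   matrix, labels = group_role_matrix(X["year"], cv.split(X, y, X["year"]))
#   plot_group_roles(ax, matrix, labels)
```

Because every cell has the same width, a small early year and a large late year take the same space, and cells cannot overlap by construction.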
