Skip to content

Commit

Permalink
Update Darts to 0.29 and add TSMixer (#64)
Browse files Browse the repository at this point in the history
* Update darts from 0.28 to 0.29

* Initial addition of TSMixerModel

* Remove blank hyperparam bug

* No bugs with old log_tensorboard param name

* Fix has propagated through statsforecast and darts
  • Loading branch information
mepland committed May 12, 2024
1 parent 3e78679 commit dfbae61
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 44 deletions.
92 changes: 92 additions & 0 deletions TSModelWrappers/TSMixerModelWrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# pylint: disable=invalid-name,duplicate-code
"""Wrapper for TSMixer."""
# pylint: enable=invalid-name

from typing import Any

from darts.models import TSMixerModel

from TSModelWrappers.TSModelWrapper import (
DATA_FIXED_HYPERPARAMS,
DATA_REQUIRED_HYPERPARAMS,
DATA_VARIABLE_HYPERPARAMS,
NN_ALLOWED_VARIABLE_HYPERPARAMS,
NN_FIXED_HYPERPARAMS,
NN_REQUIRED_HYPERPARAMS,
TSModelWrapper,
)

__all__ = ["TSMixerModelWrapper"]


class TSMixerModelWrapper(TSModelWrapper):
    """Wrapper around darts' TSMixerModel.

    See https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tsmixer_model.html
    """

    # Configuration for TSMixerModel, consumed by the TSModelWrapper machinery.
    _model_class = TSMixerModel
    _model_type = "torch"
    _required_hyperparams_data = DATA_REQUIRED_HYPERPARAMS
    _required_hyperparams_model = NN_REQUIRED_HYPERPARAMS + [
        "hidden_size",
        "ff_size",
        "num_blocks",
        "normalize_before",
    ]
    _allowed_variable_hyperparams = {**DATA_VARIABLE_HYPERPARAMS, **NN_ALLOWED_VARIABLE_HYPERPARAMS}
    _fixed_hyperparams = {**DATA_FIXED_HYPERPARAMS, **NN_FIXED_HYPERPARAMS}

    # leave the following hyperparameters at their default values:
    # output_chunk_shift ~ 0
    # use_static_covariates ~ True
    # norm_type ~ LayerNorm
    # activation ~ ReLU

    def __init__(self: "TSMixerModelWrapper", **kwargs: Any) -> None:  # noqa: ANN401
        # boilerplate - the same for all models below here
        # NOTE `isinstance(kwargs["TSModelWrapper"], TSModelWrapper)`,
        # or even `issubclass(type(kwargs["TSModelWrapper"]), TSModelWrapper)` would be preferable,
        # but they do not work if the kwargs["TSModelWrapper"] parent instance was updated
        # between child __init__ calls, hence the metaclass/str comparisons below.
        parent_wrapper = kwargs.get("TSModelWrapper")
        parent_is_usable = (
            parent_wrapper is not None
            and type(parent_wrapper.__class__)  # noqa: E721 # pylint: disable=unidiomatic-typecheck
            == type(TSModelWrapper)  # <class 'type'>
            and str(parent_wrapper.__class__)
            == str(TSModelWrapper)  # <class 'TSModelWrappers.TSModelWrappers.TSModelWrapper'>
        )
        if parent_is_usable:
            # Clone the parent wrapper's state, then overlay this model's configuration.
            self.__dict__ = parent_wrapper.__dict__.copy()
            self.model_class = self._model_class
            self.model_type = self._model_type
            self.verbose = kwargs.get("verbose", 1)
            self.work_dir = kwargs.get("work_dir")
            self.model_name_tag = kwargs.get("model_name_tag")
            self.required_hyperparams_data = self._required_hyperparams_data
            self.required_hyperparams_model = self._required_hyperparams_model
            self.allowed_variable_hyperparams = self._allowed_variable_hyperparams
            self.variable_hyperparams = kwargs.get("variable_hyperparams", {})
            self.fixed_hyperparams = self._fixed_hyperparams
        else:
            # No usable parent wrapper: construct from scratch via the base class.
            super().__init__(
                dfp_trainable_evergreen=kwargs["dfp_trainable_evergreen"],
                dt_val_start_datetime_local=kwargs["dt_val_start_datetime_local"],
                work_dir_base=kwargs["work_dir_base"],
                random_state=kwargs["random_state"],
                date_fmt=kwargs["date_fmt"],
                time_fmt=kwargs["time_fmt"],
                fname_datetime_fmt=kwargs["fname_datetime_fmt"],
                local_timezone=kwargs["local_timezone"],
                model_class=self._model_class,
                model_type=self._model_type,
                verbose=kwargs.get("verbose", 1),
                work_dir=kwargs["work_dir"],
                model_name_tag=kwargs.get("model_name_tag"),
                required_hyperparams_data=self._required_hyperparams_data,
                required_hyperparams_model=self._required_hyperparams_model,
                allowed_variable_hyperparams=self._allowed_variable_hyperparams,
                variable_hyperparams=kwargs.get("variable_hyperparams"),
                fixed_hyperparams=self._fixed_hyperparams,
            )
26 changes: 16 additions & 10 deletions TSModelWrappers/TSModelWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,7 @@
import torch
import torchmetrics
from darts import TimeSeries

with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
# Reported in https://github.com/Nixtla/statsforecast/issues/781
# Fixed in https://github.com/Nixtla/statsforecast/pull/786
# Leaving warning filter as the patch needs to propagate through statsforecast and darts releases
from darts.models.forecasting.forecasting_model import ForecastingModel

from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.utils.callbacks import TFMProgressBar
from darts.utils.missing_values import fill_missing_values, missing_values_ratio
from darts.utils.utils import ModelMode, SeasonalityMode
Expand Down Expand Up @@ -356,7 +349,7 @@ def get_lr_scheduler_kwargs(lr_factor: float, lr_patience: int) -> dict:
"default": 30,
"type": int,
},
"num_blocks": {
"num_blocks": { # and TSMixerModel
"min": 1,
"max": 10,
"default": 1,
Expand Down Expand Up @@ -443,7 +436,7 @@ def get_lr_scheduler_kwargs(lr_factor: float, lr_patience: int) -> dict:
"type": int,
},
# TFTModel hyperparams
"hidden_size": { # and TiDEModel
"hidden_size": { # and TiDEModel, TSMixerModel
"min": 1,
"max": 256,
"default": 16,
Expand Down Expand Up @@ -473,6 +466,19 @@ def get_lr_scheduler_kwargs(lr_factor: float, lr_patience: int) -> dict:
"default": 8,
"type": int,
},
# TSMixerModel hyperparams
"ff_size": {
"min": 1,
"max": 256,
"default": 64,
"type": int,
},
"normalize_before": {
"min": 0,
"max": 1,
"default": 0,
"type": bool,
},
# DLinearModel and NLinearModel hyperparams
"const_init": {
"min": 0,
Expand Down
58 changes: 30 additions & 28 deletions ana/drive_bayesian_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from TSModelWrappers.TCNModelWrapper import TCNModelWrapper
from TSModelWrappers.TransformerModelWrapper import TransformerModelWrapper
from TSModelWrappers.TFTModelWrapper import TFTModelWrapper
from TSModelWrappers.TSMixerModelWrapper import TSMixerModelWrapper
from TSModelWrappers.DLinearModelWrapper import DLinearModelWrapper
from TSModelWrappers.NLinearModelWrapper import NLinearModelWrapper
from TSModelWrappers.TiDEModelWrapper import TiDEModelWrapper
Expand Down Expand Up @@ -126,64 +127,65 @@ def drive_bayesian_opt(
{"model_wrapper_class": TCNModelWrapper}, # +i_model=3
{"model_wrapper_class": TransformerModelWrapper}, # +i_model=4
{"model_wrapper_class": TFTModelWrapper}, # +i_model=5
{"model_wrapper_class": DLinearModelWrapper}, # +i_model=6
{"model_wrapper_class": NLinearModelWrapper}, # +i_model=7
{"model_wrapper_class": TiDEModelWrapper}, # +i_model=8
{"model_wrapper_class": TSMixerModelWrapper}, # +i_model=6
{"model_wrapper_class": DLinearModelWrapper}, # +i_model=7
{"model_wrapper_class": NLinearModelWrapper}, # +i_model=8
{"model_wrapper_class": TiDEModelWrapper}, # +i_model=9
{
"model_wrapper_class": RNNModelWrapper,
"model_wrapper_kwargs": {"model": "RNN"},
}, # +i_model=9
}, # +i_model=10
{
"model_wrapper_class": RNNModelWrapper,
"model_wrapper_kwargs": {"model": "LSTM"},
}, # +i_model=10
}, # +i_model=11
{
"model_wrapper_class": RNNModelWrapper,
"model_wrapper_kwargs": {"model": "GRU"},
}, # +i_model=11
}, # +i_model=12
{
"model_wrapper_class": BlockRNNModelWrapper,
"model_wrapper_kwargs": {"model": "RNN"},
}, # +i_model=12
}, # +i_model=13
{
"model_wrapper_class": BlockRNNModelWrapper,
"model_wrapper_kwargs": {"model": "LSTM"},
}, # +i_model=13
}, # +i_model=14
{
"model_wrapper_class": BlockRNNModelWrapper,
"model_wrapper_kwargs": {"model": "GRU"},
}, # +i_model=14
}, # +i_model=15
# Statistical Models
{"model_wrapper_class": AutoARIMAWrapper}, # +i_model=15
{"model_wrapper_class": BATSWrapper}, # +i_model=16
{"model_wrapper_class": TBATSWrapper}, # +i_model=17
{"model_wrapper_class": FourThetaWrapper}, # +i_model=18
{"model_wrapper_class": StatsForecastAutoThetaWrapper}, # +i_model=19
{"model_wrapper_class": FFTWrapper}, # +i_model=20
{"model_wrapper_class": KalmanForecasterWrapper}, # +i_model=21
{"model_wrapper_class": AutoARIMAWrapper}, # +i_model=16
{"model_wrapper_class": BATSWrapper}, # +i_model=17
{"model_wrapper_class": TBATSWrapper}, # +i_model=18
{"model_wrapper_class": FourThetaWrapper}, # +i_model=19
{"model_wrapper_class": StatsForecastAutoThetaWrapper}, # +i_model=20
{"model_wrapper_class": FFTWrapper}, # +i_model=21
{"model_wrapper_class": KalmanForecasterWrapper}, # +i_model=22
{
"model_wrapper_class": CrostonWrapper,
"model_wrapper_kwargs": {"version": "optimized"},
}, # +i_model=22
}, # +i_model=23
{
"model_wrapper_class": CrostonWrapper,
"model_wrapper_kwargs": {"version": "classic"},
}, # +i_model=23
}, # +i_model=24
{
"model_wrapper_class": CrostonWrapper,
"model_wrapper_kwargs": {"version": "sba"},
}, # +i_model=24
}, # +i_model=25
# Regression Models
{"model_wrapper_class": LinearRegressionModelWrapper}, # +i_model=25
{"model_wrapper_class": RandomForestWrapper}, # +i_model=26
{"model_wrapper_class": LightGBMModelWrapper}, # +i_model=27
{"model_wrapper_class": XGBModelWrapper}, # +i_model=28
{"model_wrapper_class": CatBoostModelWrapper}, # +i_model=29
{"model_wrapper_class": LinearRegressionModelWrapper}, # +i_model=26
{"model_wrapper_class": RandomForestWrapper}, # +i_model=27
{"model_wrapper_class": LightGBMModelWrapper}, # +i_model=28
{"model_wrapper_class": XGBModelWrapper}, # +i_model=29
{"model_wrapper_class": CatBoostModelWrapper}, # +i_model=30
# Naive Models
{"model_wrapper_class": NaiveMeanWrapper}, # +i_model=30
{"model_wrapper_class": NaiveSeasonalWrapper}, # +i_model=31
{"model_wrapper_class": NaiveDriftWrapper}, # +i_model=32
{"model_wrapper_class": NaiveMovingAverageWrapper}, # +i_model=33
{"model_wrapper_class": NaiveMeanWrapper}, # +i_model=31
{"model_wrapper_class": NaiveSeasonalWrapper}, # +i_model=32
{"model_wrapper_class": NaiveDriftWrapper}, # +i_model=33
{"model_wrapper_class": NaiveMovingAverageWrapper}, # +i_model=34
]

# accept i_model CLI argument to only run one model
Expand Down
32 changes: 32 additions & 0 deletions ana/exploratory_ana.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from TSModelWrappers.TCNModelWrapper import TCNModelWrapper
from TSModelWrappers.TransformerModelWrapper import TransformerModelWrapper
from TSModelWrappers.TFTModelWrapper import TFTModelWrapper
from TSModelWrappers.TSMixerModelWrapper import TSMixerModelWrapper
from TSModelWrappers.DLinearModelWrapper import DLinearModelWrapper
from TSModelWrappers.NLinearModelWrapper import NLinearModelWrapper
from TSModelWrappers.TiDEModelWrapper import TiDEModelWrapper
Expand Down Expand Up @@ -740,6 +741,37 @@ def display_image(fname: pathlib.Path, *, plot_inline: bool = PLOT_INLINE) -> No
# %%
# # %tensorboard --logdir $tensorboard_logs

# %% [markdown]
# ### TSMixer

# %%
model_wrapper_TSMixer = TSMixerModelWrapper(
TSModelWrapper=PARENT_WRAPPER,
variable_hyperparams={"time_bin_size_in_minutes": 10},
)
model_wrapper_TSMixer.set_work_dir(work_dir_relative_to_base=pathlib.Path("local_dev"))
# print(model_wrapper_TSMixer)

# %%
configurable_hyperparams = model_wrapper_TSMixer.get_configurable_hyperparams()
pprint.pprint(configurable_hyperparams)

# %%
loss_val, metrics_val = model_wrapper_TSMixer.train_model()
print(f"metrics_val = {pprint.pformat(metrics_val)}")

# %%
print(model_wrapper_TSMixer)

# %%
tensorboard_logs = pathlib.Path(
model_wrapper_TSMixer.work_dir, model_wrapper_TSMixer.model_name, "logs" # type: ignore[arg-type]
)
print(tensorboard_logs)

# %%
# # %tensorboard --logdir $tensorboard_logs

# %% [markdown]
# ### D-Linear

Expand Down
2 changes: 1 addition & 1 deletion ana/start_bayesian_opt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pkg_path="/home/mepland/chance_of_showers"
# Models to run, see model_kwarg_list in drive_bayesian_opt.py
# prod
i_model_min=0
i_model_max=33
i_model_max=34
# dev
# i_model_min=30 # 0, 3, 20, 25, 30
# i_model_max=$i_model_min
Expand Down
10 changes: 6 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ torchvision = { version = "^0.17.1+cu118", source = "pytorch-gpu-src" }
torchaudio = { version = "^2.2.1+cu118", source = "pytorch-gpu-src" }
catboost = "^1.2.3"
lightgbm = "^4.3.0"
darts = "^0.28.0"
darts = "^0.29.0"
tensorboard = "^2.16.2"
bayesian-optimization = "^1.4.3"
xlsxwriter = "^3.2.0"
Expand Down
2 changes: 2 additions & 0 deletions utils/bayesian_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from TSModelWrappers.TCNModelWrapper import TCNModelWrapper # noqa: TC001
from TSModelWrappers.TransformerModelWrapper import TransformerModelWrapper # noqa: TC001
from TSModelWrappers.TFTModelWrapper import TFTModelWrapper # noqa: TC001
from TSModelWrappers.TSMixerModelWrapper import TSMixerModelWrapper # noqa: TC001
from TSModelWrappers.DLinearModelWrapper import DLinearModelWrapper # noqa: TC001
from TSModelWrappers.NLinearModelWrapper import NLinearModelWrapper # noqa: TC001
from TSModelWrappers.TiDEModelWrapper import TiDEModelWrapper # noqa: TC001
Expand Down Expand Up @@ -90,6 +91,7 @@
| TCNModelWrapper
| TransformerModelWrapper
| TFTModelWrapper
| TSMixerModelWrapper
| DLinearModelWrapper
| NLinearModelWrapper
| TiDEModelWrapper
Expand Down

0 comments on commit dfbae61

Please sign in to comment.