How to generate forecasts with prediction_length > 64? #40

Open
clevilll opened this issue Apr 3, 2024 · 4 comments
Labels
FAQ Frequently asked question

Comments

@clevilll

clevilll commented Apr 3, 2024

Hi,

I have time-series data that I split into train and test sets by slicing off the end of the DataFrame (keeping the test part unseen). I used your pipeline on data_train and tried, unsuccessfully, to forecast as many steps as data_test contains:

#-----------------------------------------------------------
# Libs
#-----------------------------------------------------------
# for plotting, run: pip install pandas matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline

#-----------------------------------------------------------
# LOAD THE DATASET
#-----------------------------------------------------------

df = pd.read_csv('https://raw.githubusercontent.com/amcs1729/Predicting-cloud-CPU-usage-on-Azure-data/master/azure.csv')
df['timestamp'] =  pd.to_datetime(df['timestamp'])
data = df.rename(columns={'min cpu': 'min_cpu',
                          'max cpu': 'max_cpu',
                          'avg cpu': 'avg_cpu',})



# Data preparation
# ==============================================================================
sliced_df = data[['timestamp', 'avg_cpu']].copy()  # .copy() avoids SettingWithCopyWarning when adding a column below

# Convert data from Hz to MHz
# ==============================================================================
sliced_df['avg_cpu_Mhz'] = sliced_df['avg_cpu'] / 1000000
sliced_df

# Configuration
# ==============================================================================
name_columns='avg_cpu_Mhz'
lags=288
steps=288
n_backtest=3

step_size = steps * n_backtest
data_train = sliced_df[:-step_size]
data_test  = sliced_df[-step_size:] #unseen

# Pipeline
# ==============================================================================
pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = torch.tensor(data_train['avg_cpu_Mhz'])
prediction_length = 64 #len(data_test) #12

forecast = pipeline.predict(
    context,
    prediction_length,
    num_samples=288, #20,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
) # forecast shape: [num_series, num_samples, prediction_length]

but the result is as follows:

# visualize the forecast
forecast_index = range(len(data_train), len(data_train) + prediction_length)
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(data_train['avg_cpu_Mhz'], color="royalblue", label="historical train data")
plt.plot(data_test['avg_cpu_Mhz'] , color="navy",      label="historical test data", linestyle='dashed')
plt.plot(forecast_index, median,    color="tomato",    label="median forecast")
plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")

plt.title('Chronos forecast result')
plt.ylabel(' CPU usage [MHz]',   fontsize=15)
plt.xlabel('Timestamp', fontsize=15)
plt.legend()
plt.grid()
plt.show()

[Image: resulting forecast plot]

  • How can I configure the arguments of predict() to forecast autoregressively over the whole unseen data_test?
  • Why is prediction_length recommended to be <= 64?
@abdulfatir
Contributor

abdulfatir commented Apr 4, 2024

  • You can set limit_prediction_length=False in predict(); it defaults to True in the method signature:
    limit_prediction_length: bool = True,
  • The prediction_length is recommended to be <= 64 because the models were trained to predict up to 64 steps ahead. Unrolling the model beyond that may lead to suboptimal results.
  • Since it looks like you're using a high-frequency time series (5 min), there's another important point to note: the model only uses a context of the last 512 steps, which may not be enough to correctly capture the seasonal patterns of a high-frequency time series. We have discussed this briefly in the paper in Sec. 5.6 (Context Length).
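The idea of "unrolling" a model trained on a 64-step horizon can be pictured with a toy sketch. This is not Chronos's implementation: `unroll_forecast`, `SampleFn`, and `naive_samples` are hypothetical names, and feeding the per-step median back as context is an assumption chosen for illustration. It shows why the model is called repeatedly once prediction_length exceeds 64, and why every chunk after the first is conditioned on the model's own predictions rather than on observed data.

```python
from statistics import median
from typing import Callable, List, Sequence

# Hypothetical stand-in for one model call (NOT the Chronos API): given a
# context window, return `num_samples` sample paths of length `horizon`.
SampleFn = Callable[[Sequence[float], int, int], List[List[float]]]


def unroll_forecast(
    sample_fn: SampleFn,
    context: Sequence[float],
    prediction_length: int,
    max_horizon: int = 64,
    max_context: int = 512,
    num_samples: int = 20,
) -> List[List[float]]:
    """Extend a fixed-horizon sampler past `max_horizon` autoregressively.

    Each round forecasts at most `max_horizon` steps, appends the per-step
    median of the samples to the history, and repeats until done.
    """
    history = list(context)
    paths: List[List[float]] = [[] for _ in range(num_samples)]
    remaining = prediction_length
    while remaining > 0:
        horizon = min(max_horizon, remaining)
        window = history[-max_context:]  # only the most recent steps are visible
        samples = sample_fn(window, horizon, num_samples)
        for path, chunk in zip(paths, samples):
            path.extend(chunk)
        # Treat the median forecast as if it had been observed.
        history.extend(median(step) for step in zip(*samples))
        remaining -= horizon
    return paths


def naive_samples(window: Sequence[float], horizon: int, num_samples: int) -> List[List[float]]:
    """Toy sampler: every sample path repeats the last observed value."""
    return [[window[-1]] * horizon for _ in range(num_samples)]


paths = unroll_forecast(naive_samples, [1.0, 2.0, 3.0], prediction_length=150)
print(len(paths), len(paths[0]))  # 20 150 (three rounds: 64 + 64 + 22 steps)
```

Because later rounds condition on forecasts instead of observations, errors can compound, which is one way to read the "suboptimal results" caveat.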

@abdulfatir abdulfatir changed the title How one can configure the arguments within predict() class to have forecast autoregressive over unseen data_test while prediction_length recommended to be <=64 ? How to generate forecasts with prediction_length >= 64? Apr 4, 2024
@abdulfatir abdulfatir changed the title How to generate forecasts with prediction_length >= 64? How to generate forecasts with prediction_length > 64? Apr 4, 2024
@abdulfatir abdulfatir added the FAQ Frequently asked question label Apr 4, 2024
@abdulfatir
Contributor

Alternatively, you can resample your dataset to a lower frequency. Here's an example with hourly (1H) data:

#-----------------------------------------------------------
# Libs
#-----------------------------------------------------------
# for plotting, run: pip install pandas matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline

#-----------------------------------------------------------
# LOAD THE DATASET
#-----------------------------------------------------------

df = pd.read_csv('https://raw.githubusercontent.com/amcs1729/Predicting-cloud-CPU-usage-on-Azure-data/master/azure.csv')
df['timestamp'] =  pd.to_datetime(df['timestamp'])
data = df.rename(columns={'min cpu': 'min_cpu',
                          'max cpu': 'max_cpu',
                          'avg cpu': 'avg_cpu',})



# Data preparation
# ==============================================================================
sliced_df = data[['timestamp', 'avg_cpu']].copy()  # .copy() avoids SettingWithCopyWarning when adding a column below

# Convert data from Hz to MHz
# ==============================================================================
sliced_df['avg_cpu_Mhz'] = sliced_df['avg_cpu'] / 1000000
sliced_df = sliced_df.set_index("timestamp").resample("1H").sum().reset_index()

# Configuration
# ==============================================================================
name_columns='avg_cpu_Mhz'
lags=24
steps=24
n_backtest=3

step_size = steps * n_backtest
data_train = sliced_df[:-step_size]
data_test  = sliced_df[-step_size:] #unseen

# Pipeline
# ==============================================================================
pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = torch.tensor(data_train['avg_cpu_Mhz'])
prediction_length = 72 #len(data_test) #12

forecast = pipeline.predict(
    context,
    prediction_length,
    num_samples=20, #20,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    limit_prediction_length=False
) # forecast shape: [num_series, num_samples, prediction_length]

# visualize the forecast
forecast_index = range(len(data_train), len(data_train) + prediction_length)
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(data_train['avg_cpu_Mhz'], color="royalblue", label="historical train data")
plt.plot(data_test['avg_cpu_Mhz'] , color="navy",      label="historical test data", linestyle='dashed')
plt.plot(forecast_index, median,    color="tomato",    label="median forecast")
plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")

plt.title('Chronos forecast result')
plt.ylabel(' CPU usage [MHz]',   fontsize=15)
plt.xlabel('Timestamp', fontsize=15)
plt.legend()
plt.grid()
plt.show()

Result:
[Image: resulting forecast plot for the hourly data]
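To put numbers on the trade-off, the arithmetic below converts the 512-step context and 64-step training horizon from this thread into hours, for the original 5-minute data and for the hourly resample. `coverage` is a hypothetical helper, and `model_calls` assumes the horizon is produced in chunks of at most 64 steps.

```python
import math
from datetime import timedelta

MAX_CONTEXT = 512  # context length quoted in this thread
MAX_HORIZON = 64   # training horizon quoted in this thread


def coverage(freq: timedelta, horizon_steps: int) -> dict:
    """Translate the step-based limits into wall-clock hours and model calls."""
    one_hour = timedelta(hours=1)
    return {
        "context_hours": MAX_CONTEXT * freq / one_hour,
        "horizon_hours": horizon_steps * freq / one_hour,
        "model_calls": math.ceil(horizon_steps / MAX_HORIZON),
    }


# Original setup: 5-minute data, test horizon = steps * n_backtest = 288 * 3
five_min = coverage(timedelta(minutes=5), 288 * 3)
# Resampled setup: hourly data, test horizon = 24 * 3
hourly = coverage(timedelta(hours=1), 24 * 3)

print(five_min)  # context ~42.7 h, horizon 72 h, 14 calls
print(hourly)    # context 512 h (~21 days), horizon 72 h, 2 calls
```

Both setups cover the same 72 hours of test data, but at 5-minute resolution the model sees less than two days of history (not even one weekly cycle) and needs 14 chunks, whereas hourly data gives about three weeks of context and 2 chunks.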

@clevilll
Author

clevilll commented Apr 4, 2024

@abdulfatir Thanks for your answer.

I have a few questions:

  1. Can we conclude that one of the shortcomings of Chronos is the limited prediction length for out-of-sample forecasting on high-frequency time-series data?

If so, then to avoid suboptimal results we need to resample the data. However, resampling with certain aggregation functions can sometimes damage the nature of the time series. In this case, its nature was mostly preserved when you did df.set_index("timestamp").resample("1H").sum().reset_index() with the sum() aggregation function.

  2. Based on your experience, what is the best practice if one needs to resample without damaging (or while minimally damaging) the shape and nature of the time series? (Resampling inevitably loses some information.)

@abdulfatir
Contributor

@clevilll sorry, missed this.

  1. Yes and no. It's not a limitation of the Chronos framework per se, but of the current Chronos models, which were trained to look at a maximum context of 512 steps and to forecast 64 steps into the future.
  2. Indeed, subsampling will result in a loss of information. Whether subsampling is an acceptable option depends heavily on the use case, and in my view there's no one-size-fits-all solution.
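As a small illustration of that trade-off, the sketch below downsamples a made-up 5-minute series to hourly blocks with three aggregation functions. `downsample` is a hypothetical stdlib helper rather than pandas, and the readings are invented for the example.

```python
from statistics import mean
from typing import Callable, List, Sequence


def downsample(values: Sequence[float], factor: int,
               agg: Callable[[Sequence[float]], float]) -> List[float]:
    """Aggregate consecutive, non-overlapping blocks of `factor` points.

    A trailing partial block is dropped so every output point covers the
    same duration.
    """
    n_full = len(values) // factor
    return [agg(values[i * factor:(i + 1) * factor]) for i in range(n_full)]


# Two hours of made-up 5-minute readings: steady load plus one short spike.
readings = [10.0] * 24
readings[5] = 70.0

by_sum = downsample(readings, 12, sum)    # [180.0, 120.0]
by_mean = downsample(readings, 12, mean)  # [15.0, 10.0]
by_max = downsample(readings, 12, max)    # [70.0, 10.0]
print(by_sum, by_mean, by_max)
```

With full blocks, sum and mean differ only by the constant factor 12, so they yield the same shape, which fits the observation that sum() preserved the series in the hourly example. Both smear the spike, while max keeps the peak but discards the typical level. On the pandas side, `resample(...).sum()` returns 0 for empty bins by default (`min_count=0`), so gaps in the data become artificial dips and a partial final bin has a smaller sum, whereas `mean()` yields NaN for empty bins.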
