forked from ngruver/llmtime
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from reichlab/bsweger/labmeeting-organization
Bsweger/labmeeting organization
- Loading branch information
Showing
13 changed files
with
568 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# put your openai key here and rename the file from .env-example to .env | ||
OPENAI_API_KEY=pasteyouropenaikeyhere | ||
MISTRAL_KEY=pasteyourmistralkeyhere |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
.env | ||
.envrc | ||
|
||
**/__pycache__ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import os
# Cap the OpenMP thread count BEFORE importing torch/numpy: their native
# thread pools are sized at import time, so setting this afterwards has no
# effect (the original set it after `import torch`).
os.environ['OMP_NUM_THREADS'] = '4'
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import openai
# Fail fast if no OpenAI key is exported; the API base can be redirected
# (e.g. to a proxy) via OPENAI_API_BASE.
openai.api_key = os.environ['OPENAI_API_KEY']
openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
# Project-local helpers: serializer settings, hyperparameter grid expansion,
# and the per-model prediction entry points.
from data.serialize import SerializerSettings
from models.utils import grid_iter
from models.promptcast import get_promptcast_predictions_data
from models.darts import get_arima_predictions_data
from models.llmtime import get_llmtime_predictions_data
from data.small_context import get_datasets
from models.validation_likelihood_tuning import get_autotuned_predictions_data
|
||
def plot_preds(train, test, pred_dict, model_name, show_samples=False):
    """Plot train/test series with a model's median forecast and 90% interval.

    Args:
        train: pd.Series of historical values (plotted as context).
        test: pd.Series of held-out values; its index is reused for the forecast.
        pred_dict: dict with keys 'median' (point forecast), 'samples'
            (array-like or DataFrame of sample paths, one row per sample),
            and optionally 'NLL/D' (negative log-likelihood per dimension).
        model_name: legend label for the forecast line.
        show_samples: if True, overlay up to 10 individual sample paths.
    """
    pred = pd.Series(pred_dict['median'], index=test.index)
    plt.figure(figsize=(8, 6), dpi=100)
    plt.plot(train)
    plt.plot(test, label='Truth', color='black')
    plt.plot(pred, label=model_name, color='purple')
    # Normalize samples to an ndarray ONCE (the original fetched the dict
    # entry twice and only converted from DataFrame inside the show_samples
    # branch, passing a raw DataFrame to np.quantile).
    samples = pred_dict['samples']
    samples = samples.values if isinstance(samples, pd.DataFrame) else np.asarray(samples)
    # Shade the central 90% interval across sample paths (rows = samples).
    lower = np.quantile(samples, 0.05, axis=0)
    upper = np.quantile(samples, 0.95, axis=0)
    plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple')
    if show_samples:
        for i in range(min(10, samples.shape[0])):
            plt.plot(pred.index, samples[i], color='purple', alpha=0.3, linewidth=1)
    plt.legend(loc='upper left')
    # Annotate NLL/D when the model provides it (ARIMA et al. may not).
    nll = pred_dict.get('NLL/D')
    if nll is not None:
        plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes,
                 bbox=dict(facecolor='white', alpha=0.5))
    plt.show()
|
||
|
||
|
||
# Debugging aid: report peak CUDA memory allocated so far. Guarded so the
# script still runs on CPU-only machines, where CUDA queries would fail.
if torch.cuda.is_available():
    print(torch.cuda.max_memory_allocated())
print()
|
||
# Sampling hyperparameters for GPT-4-style chat models.
gpt4_hypers = dict(
    alpha=0.3,
    basic=True,
    temp=1.0,
    top_p=0.8,
    settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')
)

# Mistral API models use the same serialization/sampling settings as GPT-4.
mistral_api_hypers = dict(
    alpha=0.3,
    basic=True,
    temp=1.0,
    top_p=0.8,
    settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')
)

# Sampling hyperparameters for GPT-3-style completion models.
gpt3_hypers = dict(
    temp=0.7,
    alpha=0.95,
    beta=0.3,
    basic=False,
    settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)
)

# LLaMA-family models reuse the GPT-3 completion settings.
llma2_hypers = dict(
    temp=0.7,
    alpha=0.95,
    beta=0.3,
    basic=False,
    settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)
)

# PromptCast serializes integers (prec=0) with no bit separator.
promptcast_hypers = dict(
    temp=0.7,
    settings=SerializerSettings(base=10, prec=0, signed=True,
                                time_sep=', ',
                                bit_sep='',
                                plus_sign='',
                                minus_sign='-',
                                half_bin_correction=False,
                                decimal_point='')
)

# ARIMA order grid: p (AR terms), d (differencing), q (MA terms).
arima_hypers = dict(p=[12,30], d=[1,2], q=[0])

# Hyperparameter grid per model label.
# FIX: 'mistral-api-small' and 'mistral-api-medium' previously pointed at the
# 'mistral-api-tiny' model (copy-paste error), so all three keys ran the same
# model. Each key now selects its own API model.
model_hypers = {
    'LLMTime GPT-3.5': {'model': 'gpt-3.5-turbo-instruct', **gpt3_hypers},
    'LLMTime GPT-4': {'model': 'gpt-4', **gpt4_hypers},
    'LLMTime GPT-3': {'model': 'text-davinci-003', **gpt3_hypers},
    'PromptCast GPT-3': {'model': 'text-davinci-003', **promptcast_hypers},
    'LLMA2': {'model': 'llama-7b', **llma2_hypers},
    'mistral': {'model': 'mistral', **llma2_hypers},
    'mistral-api-tiny': {'model': 'mistral-api-tiny', **mistral_api_hypers},
    'mistral-api-small': {'model': 'mistral-api-small', **mistral_api_hypers},
    'mistral-api-medium': {'model': 'mistral-api-medium', **mistral_api_hypers},
    'ARIMA': arima_hypers,
}
|
||
|
||
# Map each model label to the function that produces its predictions.
# All but the mistral API model are commented out so the default run is cheap.
model_predict_fns = {
    #'LLMA2': get_llmtime_predictions_data,
    #'mistral': get_llmtime_predictions_data,
    #'LLMTime GPT-4': get_llmtime_predictions_data,
    'mistral-api-tiny': get_llmtime_predictions_data
}

# Models to evaluate: every key with an active prediction function.
model_names = list(model_predict_fns.keys())
|
||
# Load the bundled benchmark datasets and pick one series to forecast.
datasets = get_datasets()
ds_name = 'AirPassengersDataset'

# Each dataset entry is a (train, test) pair of series.
data = datasets[ds_name]
train, test = data  # or change to your own data
out = {}

for model in model_names:  # GPT-4 takes about a minute to run
    model_hypers[model].update({'dataset_name': ds_name})  # for promptcast
    # Expand the hyperparameter dict into a list of concrete configurations.
    hypers = list(grid_iter(model_hypers[model]))
    num_samples = 10
    # Tune on a validation split, then sample forecasts from the best config.
    pred_dict = get_autotuned_predictions_data(train, test, hypers, num_samples, model_predict_fns[model], verbose=False, parallel=False)
    out[model] = pred_dict
    plot_preds(train, test, pred_dict, model, show_samples=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import os | ||
os.environ['OMP_NUM_THREADS'] = '4' | ||
import numpy as np | ||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
import openai | ||
# Read credentials leniently here: a `.env` file loaded by python-dotenv just
# below may be the only source of OPENAI_API_KEY, and indexing os.environ
# directly would raise KeyError before dotenv gets a chance to run. The
# post-dotenv re-read keeps the fail-fast behavior if the key is still absent.
openai.api_key = os.environ.get('OPENAI_API_KEY')
openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
from data.serialize import SerializerSettings | ||
from models.utils import grid_iter | ||
from models.promptcast import get_promptcast_predictions_data | ||
from models.darts import get_arima_predictions_data | ||
from models.llmtime import get_llmtime_predictions_data | ||
from data.small_context import get_datasets | ||
from models.validation_likelihood_tuning import get_autotuned_predictions_data | ||
|
||
# get OPENAI info from environment
# if python-dotenv is installed, try loading from .env first
try:
    from dotenv import load_dotenv
    # override=True lets values from .env take precedence over variables
    # already present in the process environment.
    load_dotenv(override=True)
except ImportError:
    print('python-dotenv not installed, not loading .env file')
# Re-read credentials now that .env (if present) has been loaded; this is the
# authoritative assignment and fails fast if the key is still missing.
openai.api_key = os.environ['OPENAI_API_KEY']
openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
|
||
def plot_preds(train, test, pred_dict, model_name, show_samples=False):
    """Plot train/test series with a model's median forecast and 90% interval.

    Non-blocking: uses plt.show(block=False) so the run loop can continue;
    the script calls a final blocking plt.show() after all models finish.

    Args:
        train: pd.Series of historical values (plotted as context).
        test: pd.Series of held-out values; its index is reused for the forecast.
        pred_dict: dict with keys 'median' (point forecast), 'samples'
            (array-like or DataFrame of sample paths, one row per sample),
            and optionally 'NLL/D' (negative log-likelihood per dimension).
        model_name: legend label for the forecast line.
        show_samples: if True, overlay up to 10 individual sample paths.
    """
    pred = pd.Series(pred_dict['median'], index=test.index)
    plt.figure(figsize=(8, 6), dpi=100)
    plt.plot(train)
    plt.plot(test, label='Truth', color='black')
    plt.plot(pred, label=model_name, color='purple')
    # Normalize samples to an ndarray ONCE (the original fetched the dict
    # entry twice and only converted from DataFrame inside the show_samples
    # branch, passing a raw DataFrame to np.quantile).
    samples = pred_dict['samples']
    samples = samples.values if isinstance(samples, pd.DataFrame) else np.asarray(samples)
    # Shade the central 90% interval across sample paths (rows = samples).
    lower = np.quantile(samples, 0.05, axis=0)
    upper = np.quantile(samples, 0.95, axis=0)
    plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple')
    if show_samples:
        for i in range(min(10, samples.shape[0])):
            plt.plot(pred.index, samples[i], color='purple', alpha=0.3, linewidth=1)
    plt.legend(loc='upper left')
    # Annotate NLL/D when the model provides it (ARIMA et al. may not).
    nll = pred_dict.get('NLL/D')
    if nll is not None:
        plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes,
                 bbox=dict(facecolor='white', alpha=0.5))
    plt.show(block=False)
|
||
|
||
# DEFINE MODELS
# Sampling hyperparameters for GPT-4-style chat models.
gpt4_hypers = dict(
    alpha=0.3,
    basic=True,
    temp=1.0,
    top_p=0.8,
    settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')
)

# Sampling hyperparameters for GPT-3-style completion models.
gpt3_hypers = dict(
    temp=0.7,
    alpha=0.95,
    beta=0.3,
    basic=False,
    settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)
)

# PromptCast serializes integers (prec=0) with no bit separator.
promptcast_hypers = dict(
    temp=0.7,
    settings=SerializerSettings(base=10, prec=0, signed=True,
                                time_sep=', ',
                                bit_sep='',
                                plus_sign='',
                                minus_sign='-',
                                half_bin_correction=False,
                                decimal_point='')
)

# ARIMA order grid: p (AR terms), d (differencing), q (MA terms).
arima_hypers = dict(p=[12,30], d=[1,2], q=[0])

# Hyperparameter grid per model label.
# NOTE(review): the 'PromptCast GPT-3' label is backed by
# 'gpt-3.5-turbo-instruct' here, not text-davinci-003 — presumably because
# davinci was retired; confirm the label/model pairing is intentional.
model_hypers = {
    'LLMTime GPT-3.5': {'model': 'gpt-3.5-turbo-instruct', **gpt3_hypers},
    'PromptCast GPT-3': {'model': 'gpt-3.5-turbo-instruct', **promptcast_hypers},
    'ARIMA': arima_hypers,
}
|
||
# Map each model label to the function that produces its predictions.
model_predict_fns = {
    'LLMTime GPT-3.5': get_llmtime_predictions_data,
    'PromptCast GPT-3': get_promptcast_predictions_data,
    'ARIMA': get_arima_predictions_data,
}

# Models to evaluate: every key with an active prediction function.
model_names = list(model_predict_fns.keys())
|
||
# RUN LLMTIME AND VISUALIZE RESULTS
# Load the bundled benchmark datasets and pick one series to forecast.
datasets = get_datasets()
ds_name = 'AirPassengersDataset'

# Each dataset entry is a (train, test) pair of series.
data = datasets[ds_name]
train, test = data  # or change to your own data
out = {}
for model in model_names:  # GPT-4 takes about a minute to run
    model_hypers[model].update({'dataset_name': ds_name})  # for promptcast
    # Expand the hyperparameter dict into a list of concrete configurations.
    hypers = list(grid_iter(model_hypers[model]))
    num_samples = 10
    # Tune on a validation split, then sample forecasts from the best config.
    pred_dict = get_autotuned_predictions_data(train, test, hypers, num_samples, model_predict_fns[model], verbose=False, parallel=False)
    out[model] = pred_dict
    plot_preds(train, test, pred_dict, model, show_samples=True)

# Keep all plot windows open
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.