diff --git a/.env-example b/.env-example
new file mode 100644
index 0000000..82966c5
--- /dev/null
+++ b/.env-example
@@ -0,0 +1,3 @@
+# put your openai key here and rename the file from .env-example to .env
+OPENAI_API_KEY=pasteyouropenaikeyhere
+MISTRAL_KEY=pasteyourmistralkeyhere
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..73a2b6c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+.env
+.envrc
+
+**/__pycache__
+
diff --git a/LICENSE b/LICENSE
index 6a454af..564b92f 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2023 Nate Gruver
+Copyright (c) 2023 Nate Gruver, Marc Finzi, Shikai Qiu
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
index 7986048..d4682a3 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ conda activate llmtime
 ```
 If you prefer not using conda, you can also install the dependencies listed in `install.sh` manually.
 
-Add your openai api key to `~/.bashrc` with
+If you want to run OpenAI models through their API (doesn't require access to a GPU), add your openai api key to `~/.bashrc` with
 ```
 echo "export OPENAI_API_KEY=" >> ~/.bashrc
 ```
@@ -34,10 +34,16 @@ echo "export OPENAI_API_BASE=" >> ~/.bashrc
 Want a quick taste of the power of LLMTime? Run the quick demo in the `demo.ipynb` notebook. No GPUs required!
 
 ## 🤖 Plugging in other LLMs
-We currently support GPT-3, GPT-3.5, GPT-4, and LLaMA 2. It's easy to plug in other LLMs by simply specifying how to generate text completions from them in `models/llms.py`.
+We currently support GPT-3, GPT-3.5, GPT-4, Mistral, and LLaMA 2. It's easy to plug in other LLMs by simply specifying how to generate text completions from them in `models/llms.py`.
+
+To run Mistral models, add your mistral api key to `~/.bashrc` with
+```
+echo "export MISTRAL_KEY=" >> ~/.bashrc
+```
 
 ## 💡 Tips
 Here are some tips for using LLMTime:
+- If you don't want to add OpenAI and Mistral keys to `~/.bashrc` or other dotfiles, you can add them to `.env-example` and rename the file to `.env` (which is in `.gitignore`). The demo code will add the contents of `.env` to your session's environment variables.
 - Performance is not too sensitive to the data scaling hyperparameters `alpha, beta, basic`. A good default is `alpha=0.95, beta=0.3, basic=False`. For data exhibiting symmetry around 0 (e.g. a sine wave), we recommend setting `basic=True` to avoid shifting the data.
 - The recently released `gpt-3.5-turbo-instruct` seems to require a lower temperature (e.g. 0.3) than other models, and tends to not outperform `text-davinci-003` from our limited experiments.
 - Tuning hyperparameters based on validation likelihoods, as done by `get_autotuned_predictions_data`, will often yield better test likelihoods, but won't necessarily yield better samples.
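For reference, the `.env` workflow described in the README tip above boils down to the pattern below. This is a minimal sketch of what the demo code in this diff does, not an additional change; it assumes `python-dotenv` is installed (see `install.sh`) and that `.env` sits in the repository root.

```python
import os
import openai

# optional: pull OPENAI_API_KEY / MISTRAL_KEY from a local .env file
try:
    from dotenv import load_dotenv
    load_dotenv(override=True)  # copies KEY=value pairs from .env into os.environ
except ImportError:
    print('python-dotenv not installed, not loading .env file')

openai.api_key = os.environ['OPENAI_API_KEY']
openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
mistral_key = os.environ.get("MISTRAL_KEY")  # read later by models/mistral_api.py
```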
diff --git a/demo.ipynb b/demo.ipynb
index ff552b6..10aca7e 100644
--- a/demo.ipynb
+++ b/demo.ipynb
@@ -23,8 +23,15 @@
     "from data.small_context import get_datasets\n",
     "from models.validation_likelihood_tuning import get_autotuned_predictions_data\n",
     "\n",
-    "%load_ext autoreload\n",
-    "%autoreload 2\n",
+    "# get OPENAI info from environment\n",
+    "# if python-dotenv is installed, try loading from .env first\n",
+    "try:\n",
+    "    from dotenv import load_dotenv\n",
+    "    load_dotenv(override=True)\n",
+    "except ImportError:\n",
+    "    print('python-dotenv not installed, not loading .env file')\n",
+    "openai.api_key = os.environ['OPENAI_API_KEY']\n",
+    "openai.api_base = os.environ.get(\"OPENAI_API_BASE\", \"https://api.openai.com/v1\")\n",
     "\n",
     "def plot_preds(train, test, pred_dict, model_name, show_samples=False):\n",
     "    pred = pred_dict['median']\n",
@@ -99,16 +106,13 @@
     "\n",
     "model_hypers = {\n",
     "    'LLMTime GPT-3.5': {'model': 'gpt-3.5-turbo-instruct', **gpt3_hypers},\n",
-    "    'LLMTime GPT-4': {'model': 'gpt-4', **gpt4_hypers},\n",
-    "    'LLMTime GPT-3': {'model': 'text-davinci-003', **gpt3_hypers},\n",
-    "    'PromptCast GPT-3': {'model': 'text-davinci-003', **promptcast_hypers},\n",
+    "    'PromptCast GPT-3': {'model': 'gpt-3.5-turbo-instruct', **promptcast_hypers},\n",
     "    'ARIMA': arima_hypers,\n",
     "    \n",
     "}\n",
     "\n",
     "model_predict_fns = {\n",
-    "    'LLMTime GPT-3': get_llmtime_predictions_data,\n",
-    "    'LLMTime GPT-4': get_llmtime_predictions_data,\n",
+    "    'LLMTime GPT-3.5': get_llmtime_predictions_data,\n",
     "    'PromptCast GPT-3': get_promptcast_predictions_data,\n",
     "    'ARIMA': get_arima_predictions_data,\n",
     "}\n",
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000..5526ca2
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,134 @@
+import os
+import torch
+os.environ['OMP_NUM_THREADS'] = '4'
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import openai
+openai.api_key = os.environ['OPENAI_API_KEY']
+openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
+from data.serialize import SerializerSettings
+from models.utils import grid_iter
+from models.promptcast import get_promptcast_predictions_data
+from models.darts import get_arima_predictions_data
+from models.llmtime import get_llmtime_predictions_data
+from data.small_context import get_datasets
+from models.validation_likelihood_tuning import get_autotuned_predictions_data
+
+def plot_preds(train, test, pred_dict, model_name, show_samples=False):
+    pred = pred_dict['median']
+    pred = pd.Series(pred, index=test.index)
+    plt.figure(figsize=(8, 6), dpi=100)
+    plt.plot(train)
+    plt.plot(test, label='Truth', color='black')
+    plt.plot(pred, label=model_name, color='purple')
+    # shade 90% confidence interval
+    samples = pred_dict['samples']
+    lower = np.quantile(samples, 0.05, axis=0)
+    upper = np.quantile(samples, 0.95, axis=0)
+    plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple')
+    if show_samples:
+        samples = pred_dict['samples']
+        # convert df to numpy array
+        samples = samples.values if isinstance(samples, pd.DataFrame) else samples
+        for i in range(min(10, samples.shape[0])):
+            plt.plot(pred.index, samples[i], color='purple', alpha=0.3, linewidth=1)
+    plt.legend(loc='upper left')
+    if 'NLL/D' in pred_dict:
+        nll = pred_dict['NLL/D']
+        if nll is not None:
+            plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes, bbox=dict(facecolor='white', alpha=0.5))
+    plt.show()
+
+
+
+print(torch.cuda.max_memory_allocated())
+print()
+
+gpt4_hypers = dict(
+    alpha=0.3,
+    basic=True,
+    temp=1.0,
+    top_p=0.8,
+    settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')
+)
+
+mistral_api_hypers = dict(
+    alpha=0.3,
+    basic=True,
+    temp=1.0,
+    top_p=0.8,
+    settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')
+)
+
+gpt3_hypers = dict(
+    temp=0.7,
+    alpha=0.95,
+    beta=0.3,
+    basic=False,
+    settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)
+)
+
+
+llma2_hypers = dict(
+    temp=0.7,
+    alpha=0.95,
+    beta=0.3,
+    basic=False,
+    settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)
+)
+
+
+promptcast_hypers = dict(
+    temp=0.7,
+    settings=SerializerSettings(base=10, prec=0, signed=True,
+                                time_sep=', ',
+                                bit_sep='',
+                                plus_sign='',
+                                minus_sign='-',
+                                half_bin_correction=False,
+                                decimal_point='')
+)
+
+arima_hypers = dict(p=[12,30], d=[1,2], q=[0])
+
+model_hypers = {
+    'LLMTime GPT-3.5': {'model': 'gpt-3.5-turbo-instruct', **gpt3_hypers},
+    'LLMTime GPT-4': {'model': 'gpt-4', **gpt4_hypers},
+    'LLMTime GPT-3': {'model': 'text-davinci-003', **gpt3_hypers},
+    'PromptCast GPT-3': {'model': 'text-davinci-003', **promptcast_hypers},
+    'LLMA2': {'model': 'llama-7b', **llma2_hypers},
+    'mistral': {'model': 'mistral', **llma2_hypers},
+    'mistral-api-tiny': {'model': 'mistral-api-tiny', **mistral_api_hypers},
+    'mistral-api-small': {'model': 'mistral-api-small', **mistral_api_hypers},
+    'mistral-api-medium': {'model': 'mistral-api-medium', **mistral_api_hypers},
+    'ARIMA': arima_hypers,
+
+    }
+
+
+model_predict_fns = {
+    #'LLMA2': get_llmtime_predictions_data,
+    #'mistral': get_llmtime_predictions_data,
+    #'LLMTime GPT-4': get_llmtime_predictions_data,
+    'mistral-api-tiny': get_llmtime_predictions_data
+}
+
+
+model_names = list(model_predict_fns.keys())
+
+datasets = get_datasets()
+ds_name = 'AirPassengersDataset'
+
+
+data = datasets[ds_name]
+train, test = data # or change to your own data
+out = {}
+
+for model in model_names: # GPT-4 takes about a minute to run
+    model_hypers[model].update({'dataset_name': ds_name}) # for promptcast
+    hypers = list(grid_iter(model_hypers[model]))
+    num_samples = 10
+    pred_dict = get_autotuned_predictions_data(train, test, hypers, num_samples, model_predict_fns[model], verbose=False, parallel=False)
+    out[model] = pred_dict
+    plot_preds(train, test, pred_dict, model, show_samples=True)
\ No newline at end of file
diff --git a/demo_openai.py b/demo_openai.py
new file mode 100644
index 0000000..883ec47
--- /dev/null
+++ b/demo_openai.py
@@ -0,0 +1,115 @@
+import os
+os.environ['OMP_NUM_THREADS'] = '4'
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import openai
+openai.api_key = os.environ.get('OPENAI_API_KEY')  # may be re-set below once .env is loaded
+openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
+from data.serialize import SerializerSettings
+from models.utils import grid_iter
+from models.promptcast import get_promptcast_predictions_data
+from models.darts import get_arima_predictions_data
+from models.llmtime import get_llmtime_predictions_data
+from data.small_context import get_datasets
+from models.validation_likelihood_tuning import get_autotuned_predictions_data
+
+# get OPENAI info from environment
+# if python-dotenv is installed, try loading from .env first
+try:
+    from dotenv import load_dotenv
+    load_dotenv(override=True)
+except ImportError:
+    print('python-dotenv not installed, not loading .env file')
+openai.api_key = os.environ['OPENAI_API_KEY']
+openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
+
+def plot_preds(train, test, pred_dict, model_name, show_samples=False):
+    pred = pred_dict['median']
+    pred = pd.Series(pred, index=test.index)
+    plt.figure(figsize=(8, 6), dpi=100)
+    plt.plot(train)
+    plt.plot(test, label='Truth', color='black')
+    plt.plot(pred, label=model_name, color='purple')
+    # shade 90% confidence interval
+    samples = pred_dict['samples']
+    lower = np.quantile(samples, 0.05, axis=0)
+    upper = np.quantile(samples, 0.95, axis=0)
+    plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple')
+    if show_samples:
+        samples = pred_dict['samples']
+        # convert df to numpy array
+        samples = samples.values if isinstance(samples, pd.DataFrame) else samples
+        for i in range(min(10, samples.shape[0])):
+            plt.plot(pred.index, samples[i], color='purple', alpha=0.3, linewidth=1)
+    plt.legend(loc='upper left')
+    if 'NLL/D' in pred_dict:
+        nll = pred_dict['NLL/D']
+        if nll is not None:
+            plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes, bbox=dict(facecolor='white', alpha=0.5))
+    plt.show(block=False)
+
+
+# DEFINE MODELS
+gpt4_hypers = dict(
+    alpha=0.3,
+    basic=True,
+    temp=1.0,
+    top_p=0.8,
+    settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')
+)
+
+gpt3_hypers = dict(
+    temp=0.7,
+    alpha=0.95,
+    beta=0.3,
+    basic=False,
+    settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)
+)
+
+
+promptcast_hypers = dict(
+    temp=0.7,
+    settings=SerializerSettings(base=10, prec=0, signed=True,
+                                time_sep=', ',
+                                bit_sep='',
+                                plus_sign='',
+                                minus_sign='-',
+                                half_bin_correction=False,
+                                decimal_point='')
+)
+
+arima_hypers = dict(p=[12,30], d=[1,2], q=[0])
+
+model_hypers = {
+    'LLMTime GPT-3.5': {'model': 'gpt-3.5-turbo-instruct', **gpt3_hypers},
+    'PromptCast GPT-3': {'model': 'gpt-3.5-turbo-instruct', **promptcast_hypers},
+    'ARIMA': arima_hypers,
+
+}
+
+model_predict_fns = {
+    'LLMTime GPT-3.5': get_llmtime_predictions_data,
+    'PromptCast GPT-3': get_promptcast_predictions_data,
+    'ARIMA': get_arima_predictions_data,
+}
+
+model_names = list(model_predict_fns.keys())
+
+# RUN LLMTIME AND VISUALIZE RESULTS
+datasets = get_datasets()
+ds_name = 'AirPassengersDataset'
+
+data = datasets[ds_name]
+train, test = data # or change to your own data
+out = {}
+for model in model_names: # GPT-4 takes about a minute to run
+    model_hypers[model].update({'dataset_name': ds_name}) # for promptcast
+    hypers = list(grid_iter(model_hypers[model]))
+    num_samples = 10
+    pred_dict = get_autotuned_predictions_data(train, test, hypers, num_samples, model_predict_fns[model], verbose=False, parallel=False)
+    out[model] = pred_dict
+    plot_preds(train, test, pred_dict, model, show_samples=True)
+
+# Keep all plot windows open
+plt.show()
\ No newline at end of file
diff --git a/install.sh b/install.sh
index 5506f60..eb9ad23 100644
--- a/install.sh
+++ b/install.sh
@@ -1,7 +1,7 @@
 conda create -n llmtime python=3.9
 conda activate llmtime
 pip install numpy
-pip install -U jax[cpu] # we don't need GPU for jax
+pip install -U "jax[cpu]" # we don't need GPU for jax
 pip install torch --index-url https://download.pytorch.org/whl/cu118
 pip install openai==0.28.1
 pip install tiktoken
@@ -16,4 +16,6 @@ pip install multiprocess
 pip install SentencePiece
 pip install accelerate
 pip install gdown
+pip install mistralai #for mistral models
+pip install python-dotenv #optional convenience for handling environment variables
 conda deactivate
diff --git a/models/gpt.py b/models/gpt.py
index 95c1dd8..886ddba 100644
--- a/models/gpt.py
+++ b/models/gpt.py
@@ -58,9 +58,9 @@ def gpt_completion_fn(model, input_str, steps, settings, num_samples, temp):
     allowed_tokens = [settings.bit_sep + str(i) for i in range(settings.base)]
     allowed_tokens += [settings.time_sep, settings.plus_sign, settings.minus_sign]
     allowed_tokens = [t for t in allowed_tokens if len(t) > 0] # remove empty tokens like an implicit plus sign
-    if (model not in ['gpt-3.5-turbo','gpt-4']): # logit bias not supported for chat models
+    if (model not in ['gpt-3.5-turbo','gpt-4','gpt-4-1106-preview']): # logit bias not supported for chat models
         logit_bias = {id: 30 for id in get_allowed_ids(allowed_tokens, model)}
-    if model in ['gpt-3.5-turbo','gpt-4']:
+    if model in ['gpt-3.5-turbo','gpt-4','gpt-4-1106-preview']:
         chatgpt_sys_message = "You are a helpful assistant that performs time series predictions. The user will provide a sequence and you will predict the remaining sequence. The sequence is represented by decimal strings separated by commas."
         extra_input = "Please continue the following sequence without producing any additional text. Do not say anything like 'the next terms in the sequence are', just return the numbers. Sequence:\n"
         response = openai.ChatCompletion.create(
diff --git a/models/llms.py b/models/llms.py
index 3e263d1..310c592 100644
--- a/models/llms.py
+++ b/models/llms.py
@@ -4,6 +4,13 @@
 from models.llama import llama_completion_fn, llama_nll_fn
 from models.llama import tokenize_fn as llama_tokenize_fn
+from models.mistral import mistral_completion_fn, mistral_nll_fn
+from models.mistral import tokenize_fn as mistral_tokenize_fn
+
+from models.mistral_api import mistral_api_completion_fn, mistral_api_nll_fn
+from models.mistral_api import tokenize_fn as mistral_api_tokenize_fn
+
+
 
 # Required: Text completion function for each model
 # -----------------------------------------------
 # Each model is mapped to a function that samples text completions.
@@ -21,7 +28,12 @@
 completion_fns = {
     'text-davinci-003': partial(gpt_completion_fn, model='text-davinci-003'),
     'gpt-4': partial(gpt_completion_fn, model='gpt-4'),
+    'gpt-4-1106-preview':partial(gpt_completion_fn, model='gpt-4-1106-preview'),
     'gpt-3.5-turbo-instruct': partial(gpt_completion_fn, model='gpt-3.5-turbo-instruct'),
+    'mistral': partial(mistral_completion_fn, model='mistral'),
+    'mistral-api-tiny': partial(mistral_api_completion_fn, model='mistral-tiny'),
+    'mistral-api-small': partial(mistral_api_completion_fn, model='mistral-small'),
+    'mistral-api-medium': partial(mistral_api_completion_fn, model='mistral-medium'),
     'llama-7b': partial(llama_completion_fn, model='7b'),
     'llama-13b': partial(llama_completion_fn, model='13b'),
     'llama-70b': partial(llama_completion_fn, model='70b'),
@@ -49,6 +61,10 @@
 # - float: Computed NLL per dimension for p(target_arr | input_arr).
 nll_fns = {
     'text-davinci-003': partial(gpt_nll_fn, model='text-davinci-003'),
+    'mistral': partial(mistral_nll_fn, model='mistral'),
+    'mistral-api-tiny': partial(mistral_api_nll_fn, model='mistral-tiny'),
+    'mistral-api-small': partial(mistral_api_nll_fn, model='mistral-small'),
+    'mistral-api-medium': partial(mistral_api_nll_fn, model='mistral-medium'),
     'llama-7b': partial(llama_nll_fn, model='7b'),
     'llama-13b': partial(llama_nll_fn, model='13b'),
     'llama-70b': partial(llama_nll_fn, model='70b'),
@@ -67,6 +83,10 @@
 tokenization_fns = {
     'text-davinci-003': partial(gpt_tokenize_fn, model='text-davinci-003'),
     'gpt-3.5-turbo-instruct': partial(gpt_tokenize_fn, model='gpt-3.5-turbo-instruct'),
+    'mistral': partial(mistral_tokenize_fn, model='mistral'),
+    'mistral-api-tiny': partial(mistral_api_tokenize_fn, model='mistral-tiny'),
+    'mistral-api-small': partial(mistral_api_tokenize_fn, model='mistral-small'),
+    'mistral-api-medium': partial(mistral_api_tokenize_fn, model='mistral-medium'),
     'llama-7b': partial(llama_tokenize_fn, model='7b'),
     'llama-13b': partial(llama_tokenize_fn, model='13b'),
     'llama-70b': partial(llama_tokenize_fn, model='70b'),
@@ -79,6 +99,10 @@
 context_lengths = {
     'text-davinci-003': 4097,
     'gpt-3.5-turbo-instruct': 4097,
+    'mistral-api-tiny': 4097,
+    'mistral-api-small': 4097,
+    'mistral-api-medium': 4097,
+    'mistral': 4096,
     'llama-7b': 4096,
     'llama-13b': 4096,
     'llama-70b': 4096,
diff --git a/models/mistral.py b/models/mistral.py
new file mode 100644
index 0000000..5717207
--- /dev/null
+++ b/models/mistral.py
@@ -0,0 +1,148 @@
+import torch
+import numpy as np
+from jax import grad,vmap
+from tqdm import tqdm
+import argparse
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+)
+from data.serialize import serialize_arr, deserialize_str, SerializerSettings
+
+DEFAULT_EOS_TOKEN = ""
+DEFAULT_BOS_TOKEN = ""
+DEFAULT_UNK_TOKEN = ""
+
+loaded = {}
+
+def get_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+    special_tokens_dict = dict()
+    if tokenizer.eos_token is None:
+        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
+    if tokenizer.bos_token is None:
+        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
+    if tokenizer.unk_token is None:
+        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
+    tokenizer.add_special_tokens(special_tokens_dict)
+    tokenizer.pad_token = tokenizer.eos_token
+    return tokenizer
+
+def get_model_and_tokenizer(model_name, cache_model=False):
+    if model_name in loaded:
+        return loaded[model_name]
+    tokenizer = get_tokenizer()
+
+    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1",device_map="cpu")
+    model.eval()
+    if cache_model:
+        loaded[model_name] = model, tokenizer
+    return model, tokenizer
+
+def tokenize_fn(str, model):
+    tokenizer = get_tokenizer()
+    return tokenizer(str)
+
+def mistral_nll_fn(model, input_arr, target_arr, settings:SerializerSettings, transform, count_seps=True, temp=1, cache_model=True):
+    """ Returns the NLL/dimension (log base e) of the target array (continuous) according to the LM
+        conditioned on the input array. Applies relevant log determinant for transforms and
+        converts from discrete NLL of the LLM to continuous by assuming uniform within the bins.
+    inputs:
+        input_arr: (n,) context array
+        target_arr: (n,) ground truth array
+        cache_model: whether to cache the model and tokenizer for faster repeated calls
+    Returns: NLL/D
+    """
+    model, tokenizer = get_model_and_tokenizer(model, cache_model=cache_model)
+
+    input_str = serialize_arr(vmap(transform)(input_arr), settings)
+    target_str = serialize_arr(vmap(transform)(target_arr), settings)
+    full_series = input_str + target_str
+
+    batch = tokenizer(
+        [full_series],
+        return_tensors="pt",
+        add_special_tokens=True
+    )
+    batch = {k: v.to(model.device) for k, v in batch.items()}  # keep inputs on the same device as the model (loaded on CPU above)
+
+    with torch.no_grad():
+        out = model(**batch)
+
+    good_tokens_str = list("0123456789" + settings.time_sep)
+    good_tokens = [tokenizer.convert_tokens_to_ids(token) for token in good_tokens_str]
+    bad_tokens = [i for i in range(len(tokenizer)) if i not in good_tokens]
+    out['logits'][:,:,bad_tokens] = -100
+
+    input_ids = batch['input_ids'][0][1:]
+    input_ids = input_ids.to('cpu')
+    logprobs = torch.nn.functional.log_softmax(out['logits'], dim=-1)[0][:-1]
+    logprobs = logprobs[torch.arange(len(input_ids)), input_ids].cpu().numpy()
+
+
+    tokens = tokenizer.batch_decode(
+        input_ids,
+        skip_special_tokens=False,
+        clean_up_tokenization_spaces=False
+    )
+
+    input_len = len(tokenizer([input_str], return_tensors="pt",)['input_ids'][0])
+    input_len = input_len - 2 # remove the BOS token
+
+    logprobs = logprobs[input_len:]
+    tokens = tokens[input_len:]
+    BPD = -logprobs.sum()/len(target_arr)
+
+    #print("BPD unadjusted:", -logprobs.sum()/len(target_arr), "BPD adjusted:", BPD)
+    # log p(x) = log p(token) - log bin_width = log p(token) + prec * log base
+    transformed_nll = BPD - settings.prec*np.log(settings.base)
+    avg_logdet_dydx = np.log(vmap(grad(transform))(target_arr)).mean()
+    return transformed_nll-avg_logdet_dydx
+
+def mistral_completion_fn(
+    model,
+    input_str,
+    steps,
+    settings,
+    batch_size=5,
+    num_samples=20,
+    temp=0.9,
+    top_p=0.9,
+    cache_model=True
+):
+    avg_tokens_per_step = len(tokenize_fn(input_str, model)['input_ids']) / len(input_str.split(settings.time_sep))
+    max_tokens = int(avg_tokens_per_step*steps)
+
+    model, tokenizer = get_model_and_tokenizer(model, cache_model=cache_model)
+
+    gen_strs = []
+    for _ in tqdm(range(num_samples // batch_size)):
+        batch = tokenizer(
+            [input_str],
+            return_tensors="pt",
+        )
+
+        batch = {k: v.repeat(batch_size, 1) for k, v in batch.items()}
+        batch = {k: v.cpu() for k, v in batch.items()}
+        num_input_ids = batch['input_ids'].shape[1]
+
+        good_tokens_str = list("0123456789" + settings.time_sep)
+        good_tokens = [tokenizer.convert_tokens_to_ids(token) for token in good_tokens_str]
+        # good_tokens += [tokenizer.eos_token_id]
+        bad_tokens = [i for i in range(len(tokenizer)) if i not in good_tokens]
+
+        generate_ids = model.generate(
+            **batch,
+            do_sample=True,
+            max_new_tokens=max_tokens,
+            temperature=temp,
+            top_p=top_p,
+            bad_words_ids=[[t] for t in bad_tokens],
+            renormalize_logits=True,
+        )
+        gen_strs += tokenizer.batch_decode(
+            generate_ids[:, num_input_ids:],
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False
+        )
+    return gen_strs
diff --git a/models/mistral_api.py b/models/mistral_api.py
new file mode 100644
index 0000000..c677921
--- /dev/null
+++ b/models/mistral_api.py
@@ -0,0 +1,108 @@
+from data.serialize import serialize_arr, SerializerSettings
+from mistralai.client import MistralClient
+from mistralai.models.chat_completion import ChatMessage
+import tiktoken
+import os
+import numpy as np
+from jax import grad,vmap
+
+loaded_model=''
+mistral_client={}
+
+def init_mistral_client(model):
+    """
+    Initialize the Mistral client for a specific LLM model.
+    """
+    global loaded_model, mistral_client
+    if mistral_client == {} or loaded_model != model:
+        loaded_model = model
+        mistral_client = MistralClient(os.environ['MISTRAL_KEY'])
+    return mistral_client
+
+def tokenize_fn(str, model):
+    """
+    Retrieve the token IDs for a string, using a GPT tokenizer (tiktoken) as a stand-in for the Mistral tokenizer.
+
+    Args:
+        str (str): str to be tokenized.
+        model (str): Name of the LLM model.
+
+    Returns:
+        list of int: List of corresponding token IDs.
+    """
+    encoding = tiktoken.encoding_for_model('gpt-3.5-turbo')
+    #encoding = init_mistral_client(model).embeddings(model="mistral-embed",input=str)
+    return encoding.encode(str)
+
+def get_allowed_ids(strs, model):
+    """
+    Retrieve the token IDs for a given list of strings, using the same GPT tokenizer stand-in.
+
+    Args:
+        strs (list of str): strs to be converted.
+        model (str): Name of the LLM model.
+
+    Returns:
+        list of int: List of corresponding token IDs.
+    """
+    encoding = tiktoken.encoding_for_model('gpt-3.5-turbo')
+    ids = []
+    for s in strs:
+        id = encoding.encode(s) #init_mistral_client(model).embeddings(model="mistral-embed",input=s)
+        ids.extend(id)
+    return ids
+
+def mistral_api_completion_fn(model, input_str, steps, settings, num_samples, temp):
+    """
+    Generate text completions from a Mistral model using the Mistral API.
+
+    Args:
+        model (str): Name of the Mistral model to use.
+        input_str (str): Serialized input time series data.
+        steps (int): Number of time steps to predict.
+        settings (SerializerSettings): Serialization settings.
+        num_samples (int): Number of completions to generate.
+        temp (float): Temperature for sampling.
+
+    Returns:
+        list of str: List of generated samples.
+    """
+    avg_tokens_per_step = len(tokenize_fn(input_str, model)) / len(input_str.split(settings.time_sep))
+    # tokens we would like to constrain generation to; the Mistral API does not expose logit bias, so this list is currently unused
+    allowed_tokens = [settings.bit_sep + str(i) for i in range(settings.base)]
+    allowed_tokens += [settings.time_sep, settings.plus_sign, settings.minus_sign]
+    allowed_tokens = [t for t in allowed_tokens if len(t) > 0] # remove empty tokens like an implicit plus sign
+    if model in ['mistral-tiny','mistral-small','mistral-medium']:
+        mistral_sys_message = "You are a helpful assistant that performs time series predictions. The user will provide a sequence and you will predict the remaining sequence. The sequence is represented by decimal strings separated by commas."
+        extra_input = "Please continue the following sequence without producing any additional text. Do not say anything like 'the next terms in the sequence are', just return the numbers. Sequence:\n"
+        response = init_mistral_client(model).chat(
+            model=model,
+            messages=[ChatMessage(role="system", content = mistral_sys_message),ChatMessage(role="user", content= (extra_input+input_str+settings.time_sep))],
+            max_tokens=int(avg_tokens_per_step*steps),
+            temperature=temp,
+        )
+        return [choice.message.content for choice in response.choices]
+
+def mistral_api_nll_fn(model, input_arr, target_arr, settings:SerializerSettings, transform, count_seps=True, temp=1):
+    """
+    Calculate the Negative Log-Likelihood (NLL) per dimension of the target array according to the LLM.
+
+    Args:
+        model (str): Name of the LLM model to use.
+        input_arr (array-like): Input array (history).
+        target_arr (array-like): Ground truth array (future).
+        settings (SerializerSettings): Serialization settings.
+        transform (callable): Transformation applied to the numerical values before serialization.
+        count_seps (bool, optional): Whether to account for separators in the calculation. Should be true for models that generate a variable number of digits. Defaults to True.
+        temp (float, optional): Temperature for sampling. Defaults to 1.
+
+    Returns:
+        float: Calculated NLL per dimension. Token logprobs are not available through the chat API used here, so this currently returns -1 as a placeholder.
+    """
+    input_str = serialize_arr(vmap(transform)(input_arr), settings)
+    target_str = serialize_arr(vmap(transform)(target_arr), settings)
+    assert input_str.endswith(settings.time_sep), f'Input string must end with {settings.time_sep}, got {input_str}'
+    full_series = input_str + target_str
+    response = init_mistral_client(model).chat_stream(model=model, messages=[ChatMessage(role="user",content=full_series)], max_tokens=0, temperature=temp,)
+    # the response does not include token logprobs, so a true NLL cannot be computed here
+    return -1
diff --git a/models/promptcast.py b/models/promptcast.py
index ecdcea5..1835198 100644
--- a/models/promptcast.py
+++ b/models/promptcast.py
@@ -273,11 +273,15 @@ def get_promptcast_predictions_data(train, test, model, settings, num_samples=10
     medians = None
     completions_list = None
    input_strs = None
+
+    if kwargs.get('parallel') is None:
+        kwargs = {**kwargs, 'parallel': True}
+
    if num_samples > 0:
        # Generate predictions
        preds, completions_list, input_strs = generate_predictions(model, inputs, steps, settings, scalers,
                                                                   num_samples=num_samples, temp=temp, prompts=prompts, post_prompts=post_prompts,
-                                                                   parallel=True, return_input_strs=True, constrain_tokens=False, strict_handling=True, **kwargs)
+                                                                   return_input_strs=True, constrain_tokens=False, strict_handling=True, **kwargs)
        # skip bad samples
        samples = [pd.DataFrame(np.array([p for p in preds[i] if p is not None]), columns=test[i].index) for i in range(len(preds))]
        medians = [sample.median(axis=0) for sample in samples]
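With the pieces above in place, the new Mistral API path can be exercised end-to-end much like `demo.py` does. The snippet below is a minimal, illustrative driver and not part of the diff: it assumes `MISTRAL_KEY` is exported (or loaded from `.env`), uses the `mistral-api-tiny` entry registered in `models/llms.py`, reuses the hyperparameters from `demo.py`, and passes the prediction function positionally, mirroring the call in `demo.py`.

```python
import os
from data.serialize import SerializerSettings
from models.utils import grid_iter
from models.llmtime import get_llmtime_predictions_data
from data.small_context import get_datasets
from models.validation_likelihood_tuning import get_autotuned_predictions_data

assert 'MISTRAL_KEY' in os.environ, 'export MISTRAL_KEY (or put it in .env) before running'

# same hyperparameters demo.py uses for the chat-style Mistral API models
mistral_api_hypers = dict(
    model='mistral-api-tiny',  # also registered: 'mistral-api-small', 'mistral-api-medium'
    alpha=0.3, basic=True, temp=1.0, top_p=0.8,
    settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-'),
)

train, test = get_datasets()['AirPassengersDataset']
hypers = list(grid_iter(mistral_api_hypers))
pred_dict = get_autotuned_predictions_data(train, test, hypers, 10,  # 10 samples
                                           get_llmtime_predictions_data,
                                           verbose=False, parallel=False)
print(pred_dict['median'])
```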