Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding configurable premise #47

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 15 additions & 2 deletions api/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ def __init__(self,
input_suffix="\n",
output_prefix="output: ",
output_suffix="\n\n",
append_output_prefix_to_query=False):
append_output_prefix_to_query=False,
premise_prefix="",
premise_suffix="\n\n"):
self.examples = {}
self.premise = ""
self.engine = engine
self.temperature = temperature
self.max_tokens = max_tokens
Expand All @@ -60,6 +63,8 @@ def __init__(self,
self.output_prefix = output_prefix
self.output_suffix = output_suffix
self.append_output_prefix_to_query = append_output_prefix_to_query
self.premise_prefix = premise_prefix
self.premise_suffix = premise_suffix
self.stop = (output_suffix + input_prefix).strip()

def add_example(self, ex):
Expand All @@ -70,6 +75,10 @@ def add_example(self, ex):
assert isinstance(ex, Example), "Please create an Example object."
self.examples[ex.get_id()] = ex

def set_premise(self, premise):
"""Sets a premise on the object. """
self.premise = premise

def delete_example(self, id):
"""Delete example with the specific id."""
if id in self.examples:
Expand Down Expand Up @@ -102,7 +111,11 @@ def get_max_tokens(self):

def craft_query(self, prompt):
"""Creates the query for the API request."""
q = self.get_prime_text(
if self.premise:
q = self.premise_prefix + self.premise + self.premise_suffix
else:
q = ""
q = q + self.get_prime_text(
) + self.input_prefix + prompt + self.input_suffix
if self.append_output_prefix_to_query:
q = q + self.output_prefix
Expand Down
40 changes: 40 additions & 0 deletions examples/run_twitter_fiction_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

from api import GPT, Example, UIConfig
from api import demo_web_app

PROMPT_EXAMPLE_URL = "https://raw.githubusercontent.com/ml4j/gpt-scrolls/master/tweets/twitter-fiction-prompt.json"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As per comment below, these external URLs are really for demonstration purposes for this PR review - I'd suggest committing the raw json files as part of this PR should you wish to merge this into the codebase)

TEMPLATE_EXAMPLE_URL = "https://raw.githubusercontent.com/ml4j/gtp-3-prompt-templates/master/question-answer/default/templates/question_answer_template_2.json"

import requests
import json

prompt_example_json = json.loads(requests.get(PROMPT_EXAMPLE_URL).text)
template_json = json.loads(requests.get(TEMPLATE_EXAMPLE_URL).text)

# Construct GPT object and show some examples
gpt = GPT(engine="davinci",
temperature=1.1,
max_tokens=100,
input_prefix=template_json['questionPrefix'],
input_suffix=template_json['questionSuffix'],
output_prefix=template_json['answerPrefix'],
output_suffix=template_json['answerSuffix'],
append_output_prefix_to_query=False,
premise_prefix=template_json['premisePrefix'],
premise_suffix=template_json['premiseSuffix'])

gpt.set_premise(prompt_example_json['premise'])

for example in prompt_example_json['questionsAndAnswers']:
gpt.add_example(Example(example['question'], example['answer']))

# Define UI configuration
config = UIConfig(description="Twitter Fiction",
button_text="Generate",
placeholder=prompt_example_json['defaultPromptQuestion'])


demo_web_app(gpt, config)