Skip to content

Commit

Permalink
feat: add OpenAIChat impl with gpt-4 support (#30)
Browse files Browse the repository at this point in the history
* feat: add OpenAIChat impl with gpt-4 support

* use fixed dep

---------

Co-authored-by: Douglas Reid <doug@steamship.com>
  • Loading branch information
douglas-reid and Douglas Reid committed Mar 20, 2023
1 parent 1a00c55 commit 2c6be2f
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 5 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
steamship==2.13.19
steamship==2.14.1
langchain==0.0.104
tiktoken==0.2.0
pydantic==1.10.2
106 changes: 103 additions & 3 deletions src/steamship_langchain/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import tiktoken
from langchain.llms.base import Generation, LLMResult
from langchain.llms.openai import BaseOpenAI
from pydantic import root_validator
from steamship import Block, File, Steamship, SteamshipError
from langchain.llms.openai import OpenAIChat as BaseOpenAIChat
from pydantic import Extra, root_validator
from steamship import Block, File, MimeTypes, PluginInstance, Steamship, SteamshipError, Tag
from steamship.data import TagKind, TagValueKey
from steamship.data.tags.tag_constants import RoleTag

PLUGIN_HANDLE: str = "gpt-3"
ARGUMENT_WHITELIST = {
Expand Down Expand Up @@ -91,7 +93,6 @@ def add_default_request_timeout(cls, values: Dict[str, Any]) -> Dict[str, Any]:

@root_validator(pre=True)
def raise_on_unsupported_arguments(cls, values: Dict[str, Any]) -> Dict[str, Any]: # noqa: N805

if unsupported_arguments := set(values.keys()) - ARGUMENT_WHITELIST:
raise NotImplementedError(f"Found unsupported argument: {unsupported_arguments}")
return values
Expand Down Expand Up @@ -188,3 +189,102 @@ def _batch(
pass

return generations, token_usage


class OpenAIChat(BaseOpenAIChat):
_llm_plugin: PluginInstance

class Config:
"""Configuration for this pydantic object."""

extra = Extra.allow

def __init__(
self, client: Steamship, model_name: str = "gpt-4", moderate_output: bool = True, **kwargs
):
super().__init__(client=client, model_name=model_name, **kwargs)
plugin_config = {"model": self.model_name, "moderate_output": moderate_output}
if self.openai_api_key:
plugin_config["openai_api_key"] = self.openai_api_key

model_args = self.model_kwargs
for arg in [
"max_tokens",
"temperature",
"top_p",
"presence_penalty",
"frequency_penalty",
"max_retries",
]:
if model_args.get(arg):
plugin_config[arg] = model_args[arg]

self._llm_plugin = self.client.use_plugin(
plugin_handle="gpt-4",
config=plugin_config,
fetch_if_exists=True,
)

@root_validator()
def validate_environment(cls, values: Dict) -> Dict: # noqa: N805
return values

def _completion(self, messages: [Dict[str, str]], **params) -> str:
blocks = []
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")

if len(content) > 0:
role_tag = RoleTag.USER
if role.lower() == "system":
role_tag = RoleTag.SYSTEM
elif role.lower() == "assistant":
role_tag = RoleTag.ASSISTANT
blocks.append(
Block(
text=content,
tags=[Tag(kind=TagKind.ROLE, name=role_tag)],
mime_type=MimeTypes.TXT,
)
)

file = File.create(self.client, blocks=blocks)
generate_task = self._llm_plugin.generate(input_file_id=file.id, options=params)
generate_task.wait()
return generate_task.output.blocks[0].text

def _generate(self, prompts: List[str], stop: Optional[List[str]] = None) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
generated_text = self._completion(messages=messages, **params)
return LLMResult(
generations=[[Generation(text=generated_text)]],
# TODO(dougreid): token usage calculations !!!
)

@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
**{
"model_name": self.model_name,
"workspace_handle": self.client.get_workspace().handle,
"plugin_handle": "gpt-4",
},
**self._default_params,
}

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "steamship-openai-chat"

def get_num_tokens(self, text: str) -> int:
"""Calculate num tokens with tiktoken package."""
encoder = "p50k_base"
enc = tiktoken.get_encoding(encoder)
tokenized_text = enc.encode(text)
return len(tokenized_text)

async def agenerate(self, prompts: List[str], stop: Optional[List[str]] = None) -> LLMResult:
raise NotImplementedError("Support for async is not provided yet.")
34 changes: 33 additions & 1 deletion tests/llms/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain.llms.loading import load_llm
from steamship import Steamship

from steamship_langchain.llms.openai import OpenAI
from steamship_langchain.llms.openai import OpenAI, OpenAIChat


@pytest.mark.usefixtures("client")
Expand Down Expand Up @@ -121,3 +121,35 @@ def test_openai_streaming_unsupported(client: Steamship) -> None:
llm = OpenAI(client=client, max_tokens=10)
with pytest.raises(NotImplementedError):
llm.stream("I'm Pickle Rick")


@pytest.mark.usefixtures("client")
def test_openai_chat_llm(client: Steamship) -> None:
"""Test Chat version of the LLM"""
llm = OpenAIChat(client=client)
llm_result = llm.generate(prompts=["Please say the Pledge of Allegiance"], stop=["flag"])
assert len(llm_result.generations) == 1
generation = llm_result.generations[0]
assert len(generation) == 1
text_response = generation[0].text
assert text_response.strip() == "I pledge allegiance to the"


@pytest.mark.usefixtures("client")
def test_openai_chat_llm_with_prefixed_messages(client: Steamship) -> None:
"""Test Chat version of the LLM"""
messages = [
{
"role": "system",
"content": "You are EchoGPT. For every prompt you receive, you reply with the exact same text.",
},
{"role": "user", "content": "This is a test."},
{"role": "assistant", "content": "This is a test."},
]
llm = OpenAIChat(client=client, prefix_messages=messages)
llm_result = llm.generate(prompts=["What is the meaning of life?"])
assert len(llm_result.generations) == 1
generation = llm_result.generations[0]
assert len(generation) == 1
text_response = generation[0].text
assert text_response.strip() == "What is the meaning of life?"

0 comments on commit 2c6be2f

Please sign in to comment.