Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: [Ellipsis] Documentation bots #21

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 45 additions & 1 deletion sunholo/components/llm.py
Expand Up @@ -16,7 +16,14 @@

logging = setup_logging()


def pick_llm(vector_name):
"""
This function selects a language model based on a given vector name. It returns the selected language model, embeddings, and chat model.

:param vector_name: The name of the vector used to select the language model.
:return: A tuple containing the selected language model, embeddings, and chat model.
"""
logging.debug('Picking llm')

llm_str = load_config_key("llm", vector_name, filename = "config/llm_config.yaml")
Expand Down Expand Up @@ -47,7 +54,14 @@ def pick_llm(vector_name):

return llm, embeddings, llm_chat


def pick_streaming(vector_name):
"""
This function determines whether to use streaming based on the given vector name.

:param vector_name: The name of the vector used to determine whether to use streaming.
:return: A boolean indicating whether to use streaming.
"""

llm_str = load_config_key("llm", vector_name, filename = "config/llm_config.yaml")

Expand All @@ -57,7 +71,15 @@ def pick_streaming(vector_name):
return False



def get_llm(vector_name, model=None, config_file="config/llm_config.yaml"):
"""
This function gets a language model based on a given vector name and an optional model name.

:param vector_name: The name of the vector used to get the language model.
:param model: The name of the model. If not provided, a default model is used.
:return: The selected language model.
"""
llm_str = load_config_key("llm", vector_name, filename=config_file)
model_lookup_filepath = get_module_filepath("lookup/model_lookup.yaml")
model_lookup, _ = load_config(model_lookup_filepath)
Expand Down Expand Up @@ -106,7 +128,15 @@ def get_llm(vector_name, model=None, config_file="config/llm_config.yaml"):
if llm_str is None:
raise NotImplementedError(f'No llm implemented for {llm_str}')


def get_llm_chat(vector_name, model=None, config_file="config/llm_config.yaml"):
"""
This function gets a chat model based on a given vector name and an optional model name.

:param vector_name: The name of the vector used to get the chat model.
:param model: The name of the model. If not provided, a default model is used.
:return: The selected chat model.
"""
llm_str = load_config_key("llm", vector_name, filename=config_file)
if not model:
model = load_config_key("model", vector_name, filename=config_file)
Expand Down Expand Up @@ -150,14 +180,28 @@ def get_llm_chat(vector_name, model=None, config_file="config/llm_config.yaml"):
if llm_str is None:
raise NotImplementedError(f'No llm implemented for {llm_str}')


def get_embeddings(vector_name):
"""
This function gets embeddings based on a given vector name.

:param vector_name: The name of the vector used to get the embeddings.
:return: The selected embeddings.
"""
llm_str = load_config_key("llm", vector_name, filename="config/llm_config.yaml")

return pick_embedding(llm_str)




def pick_embedding(llm_str: str):
"""
This function selects embeddings based on a given language model string.

:param llm_str: The language model string used to select the embeddings.
:return: The selected embeddings.
"""
# get embedding directly from llm_str
# Configure embeddings based on llm_str
if llm_str == 'openai':
Expand All @@ -175,4 +219,4 @@ def pick_embedding(llm_str: str):
return GoogleGenerativeAIEmbeddings(model="models/embedding-001") #TODO add embedding type

if llm_str is None:
raise NotImplementedError(f'No embeddings implemented for {llm_str}')
raise NotImplementedError(f'No embeddings implemented for {llm_str}')
7 changes: 6 additions & 1 deletion sunholo/components/prompt.py
Expand Up @@ -23,6 +23,7 @@


def pick_prompt(vector_name, chat_history=[]):
"""This function selects a custom prompt based on a given vector name and an optional chat history. It returns a PromptTemplate object."""
"""Pick a custom prompt"""
logging.debug('Picking prompt')

Expand Down Expand Up @@ -92,6 +93,7 @@ def pick_prompt(vector_name, chat_history=[]):
return QA_PROMPT

def pick_chat_buddy(vector_name):
"""This function selects a chat buddy based on a given vector name. It returns the name of the chat buddy and a description of the chat buddy."""
chat_buddy = load_config_key("chat_buddy", vector_name, filename = "config/llm_config.yaml")
if chat_buddy is not None:
logging.info(f"Got chat buddy {chat_buddy} for {vector_name}")
Expand All @@ -101,19 +103,22 @@ def pick_chat_buddy(vector_name):


def pick_agent(vector_name):
"""This function determines whether to use an agent based on a given vector name. It returns a boolean indicating whether to use an agent."""
agent_str = load_config_key("agent", vector_name, filename = "config/llm_config.yaml")
if agent_str == "yes":
return True

return False

def pick_shared_vectorstore(vector_name, embeddings):
"""This function selects a shared vector store based on a given vector name and embeddings. It returns the selected vector store."""
shared_vectorstore = load_config_key("shared_vectorstore", vector_name, filename = "config/llm_config.yaml")
vectorstore = pick_vectorstore(shared_vectorstore, embeddings)
return vectorstore


def get_chat_history(inputs, vector_name, last_chars=1000, summary_chars=1500) -> str:
"""This function gets the chat history based on given inputs, a vector name, and optional parameters for the number of last characters and summary characters. It returns a string representing the chat history."""
from langchain.schema import Document
from ..summarise import summarise_docs

Expand Down Expand Up @@ -148,4 +153,4 @@ def get_chat_history(inputs, vector_name, last_chars=1000, summary_chars=1500) -
summary = text_sum[:summary_chars]

# Concatenate the summary and the last `last_chars` characters of the chat history
return summary + "\n### Recent Chat History\n..." + recent_history
return summary + "\n### Recent Chat History\n..." + recent_history