From b8ec499b9deb7103aebe23d47aa89acaa53d3805 Mon Sep 17 00:00:00 2001
From: "ellipsis-dev[bot]" <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Date: Mon, 25 Mar 2024 19:24:28 +0000
Subject: [PATCH] implement #18;

---
 sunholo/components/llm.py    | 46 +++++++++++++++++++++++++++++++++++-
 sunholo/components/prompt.py |  7 +++++-
 2 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/sunholo/components/llm.py b/sunholo/components/llm.py
index 1d55e4c..944ba07 100644
--- a/sunholo/components/llm.py
+++ b/sunholo/components/llm.py
@@ -16,7 +16,14 @@
 
 logging = setup_logging()
 
+
 def pick_llm(vector_name):
+    """
+    This function selects a language model based on a given vector name. It returns the selected language model, embeddings, and chat model.
+
+    :param vector_name: The name of the vector used to select the language model.
+    :return: A tuple containing the selected language model, embeddings, and chat model.
+    """
     logging.debug('Picking llm')
 
     llm_str = load_config_key("llm", vector_name, filename = "config/llm_config.yaml")
@@ -47,7 +54,14 @@ def pick_llm(vector_name):
 
     return llm, embeddings, llm_chat
 
+
 def pick_streaming(vector_name):
+    """
+    This function determines whether to use streaming based on the given vector name.
+
+    :param vector_name: The name of the vector used to determine whether to use streaming.
+    :return: A boolean indicating whether to use streaming.
+    """
     llm_str = load_config_key("llm", vector_name, filename = "config/llm_config.yaml")
 
 
@@ -57,7 +71,15 @@ def pick_streaming(vector_name):
 
     return False
 
+
 def get_llm(vector_name, model=None, config_file="config/llm_config.yaml"):
+    """
+    This function gets a language model based on a given vector name and an optional model name.
+
+    :param vector_name: The name of the vector used to get the language model.
+    :param model: The name of the model. If not provided, a default model is used.
+    :return: The selected language model.
+    """
     llm_str = load_config_key("llm", vector_name, filename=config_file)
     model_lookup_filepath = get_module_filepath("lookup/model_lookup.yaml")
     model_lookup, _ = load_config(model_lookup_filepath)
@@ -106,7 +128,15 @@ def get_llm(vector_name, model=None, config_file="config/llm_config.yaml"):
     if llm_str is None:
         raise NotImplementedError(f'No llm implemented for {llm_str}')
 
+
 def get_llm_chat(vector_name, model=None, config_file="config/llm_config.yaml"):
+    """
+    This function gets a chat model based on a given vector name and an optional model name.
+
+    :param vector_name: The name of the vector used to get the chat model.
+    :param model: The name of the model. If not provided, a default model is used.
+    :return: The selected chat model.
+    """
     llm_str = load_config_key("llm", vector_name, filename=config_file)
     if not model:
         model = load_config_key("model", vector_name, filename=config_file)
@@ -150,14 +180,28 @@ def get_llm_chat(vector_name, model=None, config_file="config/llm_config.yaml"):
 
     if llm_str is None:
         raise NotImplementedError(f'No llm implemented for {llm_str}')
 
+
 def get_embeddings(vector_name):
+    """
+    This function gets embeddings based on a given vector name.
+
+    :param vector_name: The name of the vector used to get the embeddings.
+    :return: The selected embeddings.
+    """
     llm_str = load_config_key("llm", vector_name, filename="config/llm_config.yaml")
 
     return pick_embedding(llm_str)
 
+
 def pick_embedding(llm_str: str):
+    """
+    This function selects embeddings based on a given language model string.
+
+    :param llm_str: The language model string used to select the embeddings.
+    :return: The selected embeddings.
+    """
     # get embedding directly from llm_str
     # Configure embeddings based on llm_str
     if llm_str == 'openai':
@@ -175,4 +219,4 @@ def pick_embedding(llm_str: str):
         return GoogleGenerativeAIEmbeddings(model="models/embedding-001") #TODO add embedding type
 
     if llm_str is None:
-        raise NotImplementedError(f'No embeddings implemented for {llm_str}')
+        raise NotImplementedError(f'No embeddings implemented for {llm_str}')
\ No newline at end of file
diff --git a/sunholo/components/prompt.py b/sunholo/components/prompt.py
index 157cccd..639d464 100644
--- a/sunholo/components/prompt.py
+++ b/sunholo/components/prompt.py
@@ -23,6 +23,7 @@
 
 
 def pick_prompt(vector_name, chat_history=[]):
+    """This function selects a custom prompt based on a given vector name and an optional chat history. It returns a PromptTemplate object."""
     """Pick a custom prompt"""
     logging.debug('Picking prompt')
 
@@ -92,6 +93,7 @@ def pick_prompt(vector_name, chat_history=[]):
     return QA_PROMPT
 
 def pick_chat_buddy(vector_name):
+    """This function selects a chat buddy based on a given vector name. It returns the name of the chat buddy and a description of the chat buddy."""
    chat_buddy = load_config_key("chat_buddy", vector_name, filename = "config/llm_config.yaml")
     if chat_buddy is not None:
         logging.info(f"Got chat buddy {chat_buddy} for {vector_name}")
@@ -101,6 +103,7 @@
 
 
 def pick_agent(vector_name):
+    """This function determines whether to use an agent based on a given vector name. It returns a boolean indicating whether to use an agent."""
     agent_str = load_config_key("agent", vector_name, filename = "config/llm_config.yaml")
     if agent_str == "yes":
         return True
@@ -108,12 +111,14 @@
 
     return False
 
 def pick_shared_vectorstore(vector_name, embeddings):
+    """This function selects a shared vector store based on a given vector name and embeddings. It returns the selected vector store."""
     shared_vectorstore = load_config_key("shared_vectorstore", vector_name, filename = "config/llm_config.yaml")
     vectorstore = pick_vectorstore(shared_vectorstore, embeddings)
 
     return vectorstore
 
 def get_chat_history(inputs, vector_name, last_chars=1000, summary_chars=1500) -> str:
+    """This function gets the chat history based on given inputs, a vector name, and optional parameters for the number of last characters and summary characters. It returns a string representing the chat history."""
     from langchain.schema import Document
     from ..summarise import summarise_docs
@@ -148,4 +153,4 @@ def get_chat_history(inputs, vector_name, last_chars=1000, summary_chars=1500) -> str:
     summary = text_sum[:summary_chars]
 
     # Concatenate the summary and the last `last_chars` characters of the chat history
-    return summary + "\n### Recent Chat History\n..." + recent_history
+    return summary + "\n### Recent Chat History\n..." + recent_history
\ No newline at end of file
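Note for reviewers: a minimal usage sketch of the entry points documented above. This is illustrative only and not part of the patch; the vector name "my_vector" is a placeholder that would need to exist in config/llm_config.yaml.

    from sunholo.components.llm import pick_llm, get_embeddings

    # pick_llm returns the tuple described in its new docstring:
    # (language model, embeddings, chat model)
    llm, embeddings, llm_chat = pick_llm("my_vector")

    # get_embeddings resolves the "llm" key for the vector from
    # config/llm_config.yaml and delegates to pick_embedding
    embeddings = get_embeddings("my_vector")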