From 1b777bfc8c8bde571e77e245858c94bad6020781 Mon Sep 17 00:00:00 2001
From: "ellipsis-dev[bot]" <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Date: Mon, 25 Mar 2024 16:42:53 +0000
Subject: [PATCH] implement #18;

---
 sunholo/components/README.md      | 52 +++++++++++++++++++++++++++++++
 sunholo/components/llm.py         | 51 ++++++++++++++++++++++++++++++
 sunholo/components/prompt.py      |  4 +++
 sunholo/components/retriever.py   |  2 ++
 sunholo/components/vectorstore.py | 13 ++++++--
 5 files changed, 120 insertions(+), 2 deletions(-)
 create mode 100644 sunholo/components/README.md

diff --git a/sunholo/components/README.md b/sunholo/components/README.md
new file mode 100644
index 0000000..70aecb4
--- /dev/null
+++ b/sunholo/components/README.md
@@ -0,0 +1,52 @@
+# Sunholo Components
+
+This directory contains the components used in the Sunholo project. Each file covers one component and contains the functions related to it.
+
+## Files in this directory
+
+- `llm.py`: Functions for picking and instantiating language models, chat models, and embeddings.
+- `prompt.py`: Functions for building prompts, chat buddies, and chat history summaries.
+- `retriever.py`: Functions for loading memory settings and building retrievers.
+- `vectorstore.py`: Functions for picking a vector store implementation.
+
+## Functions
+
+### llm.py
+
+- `pick_llm`: Pick a language model, embeddings, and chat model for a vector name.
+- `pick_streaming`: Check whether the configured LLM supports streaming.
+- `get_llm` / `get_llm_chat`: Instantiate the configured language or chat model.
+- `get_embeddings` / `pick_embedding`: Return the configured embeddings.
+
+### prompt.py
+
+- `pick_chat_buddy`: Return the configured chat buddy and its description.
+- `pick_agent`: Return True if the agent setting is enabled.
+- `pick_shared_vectorstore`: Return the configured shared vector store.
+- `get_chat_history`: Summarise the chat history into a compact string.
+
+### retriever.py
+
+- `load_memories`: Load the memory settings for a vector name.
+- `pick_retriever`: Build a ContextualCompressionRetriever from the memory settings.
+
+### vectorstore.py
+
+- `pick_vectorstore`: Pick a vector store ('supabase', 'cloudsql', 'alloydb', or 'lancedb').
+
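+## Example usage
+
+A minimal sketch of how these components fit together. The vector name
+`my_vector` and its entry in `config/llm_config.yaml` are assumptions for
+illustration; substitute your own configuration.
+
+```python
+from sunholo.components.llm import pick_llm
+from sunholo.components.retriever import pick_retriever
+
+# 'my_vector' is a hypothetical entry in config/llm_config.yaml
+llm, embeddings, llm_chat = pick_llm("my_vector")
+retriever = pick_retriever("my_vector", embeddings=embeddings)
+```
+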
+ """ llm_str = load_config_key("llm", vector_name, filename=config_file) model_lookup_filepath = get_module_filepath("lookup/model_lookup.yaml") model_lookup, _ = load_config(model_lookup_filepath) @@ -106,7 +129,16 @@ def get_llm(vector_name, model=None, config_file="config/llm_config.yaml"): if llm_str is None: raise NotImplementedError(f'No llm implemented for {llm_str}') + def get_llm_chat(vector_name, model=None, config_file="config/llm_config.yaml"): + """ + This function gets a chat model based on the 'llm' key in the 'config/llm_config.yaml' file. It returns the chosen chat model. + + :param vector_name: The name of the vector to be used. + :param model: The model to be used. If not provided, a model is loaded from the 'config/llm_config.yaml' file. + :param config_file: The configuration file to be used. Default is 'config/llm_config.yaml'. + :return: The chosen chat model. + """ llm_str = load_config_key("llm", vector_name, filename=config_file) if not model: model = load_config_key("model", vector_name, filename=config_file) @@ -150,14 +182,28 @@ def get_llm_chat(vector_name, model=None, config_file="config/llm_config.yaml"): if llm_str is None: raise NotImplementedError(f'No llm implemented for {llm_str}') + def get_embeddings(vector_name): + """ + This function gets embeddings based on the 'llm' key in the 'config/llm_config.yaml' file. It returns the chosen embeddings. + + :param vector_name: The name of the vector to be used. + :return: The chosen embeddings. + """ llm_str = load_config_key("llm", vector_name, filename="config/llm_config.yaml") return pick_embedding(llm_str) + def pick_embedding(llm_str: str): + """ + This function picks embeddings based on the 'llm' key. It returns the chosen embeddings. + + :param llm_str: The 'llm' key to be used. + :return: The chosen embeddings. + """ # get embedding directly from llm_str # Configure embeddings based on llm_str if llm_str == 'openai': @@ -175,4 +221,4 @@ def pick_embedding(llm_str: str): return GoogleGenerativeAIEmbeddings(model="models/embedding-001") #TODO add embedding type if llm_str is None: - raise NotImplementedError(f'No embeddings implemented for {llm_str}') + raise NotImplementedError(f'No embeddings implemented for {llm_str}') \ No newline at end of file diff --git a/sunholo/components/prompt.py b/sunholo/components/prompt.py index 157cccd..1412456 100644 --- a/sunholo/components/prompt.py +++ b/sunholo/components/prompt.py @@ -92,6 +92,7 @@ def pick_prompt(vector_name, chat_history=[]): return QA_PROMPT def pick_chat_buddy(vector_name): + """This function picks a chat buddy based on the 'chat_buddy' key in the 'config/llm_config.yaml' file. It returns the chosen chat buddy and a description of the chat buddy.""" chat_buddy = load_config_key("chat_buddy", vector_name, filename = "config/llm_config.yaml") if chat_buddy is not None: logging.info(f"Got chat buddy {chat_buddy} for {vector_name}") @@ -101,6 +102,7 @@ def pick_chat_buddy(vector_name): def pick_agent(vector_name): + """This function checks if the 'agent' key in the 'config/llm_config.yaml' file is 'yes'. If it is, the function returns True. Otherwise, it returns False.""" agent_str = load_config_key("agent", vector_name, filename = "config/llm_config.yaml") if agent_str == "yes": return True @@ -108,12 +110,14 @@ def pick_agent(vector_name): return False def pick_shared_vectorstore(vector_name, embeddings): + """This function picks a shared vector store based on the 'shared_vectorstore' key in the 'config/llm_config.yaml' file. 
+    """
     llm_str = load_config_key("llm", vector_name, filename=config_file)
     if not model:
         model = load_config_key("model", vector_name, filename=config_file)
@@ -150,14 +187,28 @@ def get_llm_chat(vector_name, model=None, config_file="config/llm_config.yaml"):
     if llm_str is None:
         raise NotImplementedError(f'No llm implemented for {llm_str}')
 
+
 def get_embeddings(vector_name):
+    """
+    Get the embeddings configured for the given vector name.
+
+    :param vector_name: The name of the vector configuration to look up.
+    :return: The configured embeddings instance.
+    """
     llm_str = load_config_key("llm", vector_name, filename="config/llm_config.yaml")
 
     return pick_embedding(llm_str)
 
+
 def pick_embedding(llm_str: str):
+    """
+    Pick an embeddings implementation based on the 'llm' key.
+
+    :param llm_str: The 'llm' key, e.g. 'openai' or 'gemini'.
+    :return: The chosen embeddings instance.
+    """
     # get embedding directly from llm_str
     # Configure embeddings based on llm_str
     if llm_str == 'openai':
diff --git a/sunholo/components/prompt.py b/sunholo/components/prompt.py
index 157cccd..1412456 100644
--- a/sunholo/components/prompt.py
+++ b/sunholo/components/prompt.py
@@ -92,6 +92,7 @@ def pick_prompt(vector_name, chat_history=[]):
     return QA_PROMPT
 
 def pick_chat_buddy(vector_name):
+    """Pick a chat buddy based on the 'chat_buddy' key in 'config/llm_config.yaml' and return the chosen buddy together with a description of it."""
     chat_buddy = load_config_key("chat_buddy", vector_name, filename = "config/llm_config.yaml")
     if chat_buddy is not None:
         logging.info(f"Got chat buddy {chat_buddy} for {vector_name}")
@@ -101,6 +102,7 @@ def pick_chat_buddy(vector_name):
 
 def pick_agent(vector_name):
+    """Return True if the 'agent' key in 'config/llm_config.yaml' is set to 'yes', otherwise False."""
     agent_str = load_config_key("agent", vector_name, filename = "config/llm_config.yaml")
     if agent_str == "yes":
         return True
@@ -108,12 +110,14 @@ def pick_agent(vector_name):
     return False
 
 def pick_shared_vectorstore(vector_name, embeddings):
+    """Pick a shared vector store based on the 'shared_vectorstore' key in 'config/llm_config.yaml' and return it."""
     shared_vectorstore = load_config_key("shared_vectorstore", vector_name, filename = "config/llm_config.yaml")
 
     vectorstore = pick_vectorstore(shared_vectorstore, embeddings)
 
     return vectorstore
 
 def get_chat_history(inputs, vector_name, last_chars=1000, summary_chars=1500) -> str:
+    """Return a string that combines a summary of the chat history (capped at ``summary_chars``) with the last ``last_chars`` characters of the chat history."""
     from langchain.schema import Document
     from ..summarise import summarise_docs
diff --git a/sunholo/components/retriever.py b/sunholo/components/retriever.py
index 466bf58..d18a4ac 100644
--- a/sunholo/components/retriever.py
+++ b/sunholo/components/retriever.py
@@ -27,6 +27,7 @@
 logging = setup_logging()
 
 def load_memories(vector_name):
+    """Load the memory settings for the given vector name from 'config/llm_config.yaml', returning None if none are configured."""
     memories = load_config_key("memory", vector_name, filename="config/llm_config.yaml")
     logging.info(f"Found memory settings for {vector_name}: {memories}")
     if len(memories) == 0:
@@ -36,6 +37,7 @@ def load_memories(vector_name):
 
 def pick_retriever(vector_name, embeddings=None):
+    """Build retrievers from the memory settings for the given vector name and return a ContextualCompressionRetriever, or None if no retrievers were created."""
 
     memories = load_memories(vector_name)
 
diff --git a/sunholo/components/vectorstore.py b/sunholo/components/vectorstore.py
index 239f8bf..42b522a 100644
--- a/sunholo/components/vectorstore.py
+++ b/sunholo/components/vectorstore.py
@@ -18,7 +18,16 @@
 
 logging = setup_logging()
 
 def pick_vectorstore(vs_str, vector_name, embeddings):
+    """Pick a vector store implementation based on ``vs_str``.
+
+    Currently supports 'supabase', 'cloudsql', 'alloydb', and 'lancedb'.
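+
+    Example (a sketch; assumes 'my_vector' is configured and ``embeddings``
+    came from ``get_embeddings``):
+
+        vectorstore = pick_vectorstore('lancedb', 'my_vector', embeddings)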
+    """
     logging.debug('Picking vectorstore')
 
     if vs_str == 'supabase':
@@ -69,7 +78,7 @@ def pick_vectorstore(vs_str, vector_name, embeddings):
         logging.debug("Chose CloudSQL")
 
         return vectorstore
-    
+
     elif vs_str == 'alloydb':
         from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore
         from google.cloud.alloydb.connector import IPTypes
@@ -139,4 +148,4 @@ def pick_vectorstore(vs_str, vector_name, embeddings):
         return vectorstore
 
     else:
-        raise NotImplementedError(f'No llm implemented for {vs_str}')
+        raise NotImplementedError(f'No vectorstore implemented for {vs_str}')