implement #18;
ellipsis-dev[bot] committed Mar 25, 2024
1 parent 74110c4 commit 1b777bf
Showing 5 changed files with 85 additions and 5 deletions.
22 changes: 22 additions & 0 deletions sunholo/components/README.md
@@ -0,0 +1,22 @@
# Sunholo Components

This directory contains the components used in the Sunholo project. Each file represents a different component and contains functions related to that component.

## Files in this directory

- `llm.py`: Functions for picking language models, chat models, and embeddings based on the `llm` key in `config/llm_config.yaml`.
- `prompt.py`: Functions for building prompts, chat buddies, agents, and chat history.
- `retriever.py`: Functions for loading memory settings and building retrievers.
- `vectorstore.py`: Functions for picking and configuring vector stores.

## Functions

### llm.py

- `pick_llm`: Returns the configured language model, embeddings, and chat model for a vector name.
- `pick_streaming`: Returns True when the configured `llm` is `openai`, `gemini`, or `vertex`.
- `get_llm` / `get_llm_chat`: Return a language model or chat model, optionally overriding the configured model.
- `get_embeddings` / `pick_embedding`: Return the embeddings for a vector name or `llm` key.

### prompt.py

- `pick_prompt`: Builds the QA prompt for a vector name.
- `pick_chat_buddy`: Returns the configured chat buddy and its description.
- `pick_agent`: Returns True when the configuration enables an agent.
- `pick_shared_vectorstore`: Returns the shared vector store for a vector name.
- `get_chat_history`: Summarises older chat history and appends the most recent messages.

### retriever.py

- `load_memories`: Loads the memory settings for a vector name.
- `pick_retriever`: Builds a ContextualCompressionRetriever from the memory settings.

### vectorstore.py

- `pick_vectorstore`: Returns a vector store backed by `supabase`, `cloudsql`, `alloydb`, or `lancedb`.

Please refer to the docstrings in each function for more detailed information.
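
## Example usage

A minimal sketch of pulling a configured model out of these components. It assumes the package is importable as `sunholo` and that `config/llm_config.yaml` has an `llm` entry for the hypothetical vector name `"sample"`:

```python
from sunholo.components.llm import pick_llm

# Resolve the configured language model, embeddings, and chat model in one call
llm, embeddings, llm_chat = pick_llm("sample")
```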
48 changes: 47 additions & 1 deletion sunholo/components/llm.py
@@ -16,7 +16,14 @@

logging = setup_logging()


def pick_llm(vector_name):
"""
This function picks a language model based on the 'llm' key in the 'config/llm_config.yaml' file. It returns the chosen language model, embeddings, and a chat model.
:param vector_name: The name of the vector to be used.
:return: A tuple containing the chosen language model, embeddings, and a chat model.
"""
logging.debug('Picking llm')

llm_str = load_config_key("llm", vector_name, filename = "config/llm_config.yaml")
@@ -47,7 +54,14 @@ def pick_llm(vector_name):

return llm, embeddings, llm_chat


def pick_streaming(vector_name):
"""
This function checks if the 'llm' key in the 'config/llm_config.yaml' file is either 'openai', 'gemini', or 'vertex'. If it is, the function returns True. Otherwise, it returns False.
:param vector_name: The name of the vector to be used.
:return: True if the 'llm' key is either 'openai', 'gemini', or 'vertex'. False otherwise.
"""

llm_str = load_config_key("llm", vector_name, filename = "config/llm_config.yaml")

@@ -57,7 +71,16 @@ def pick_streaming(vector_name):
return False



def get_llm(vector_name, model=None, config_file="config/llm_config.yaml"):
"""
This function gets a language model based on the 'llm' key in the given configuration file (default 'config/llm_config.yaml'). It returns the chosen language model.
:param vector_name: The name of the vector to be used.
:param model: The model to be used. If not provided, a model is loaded from the 'config/llm_config.yaml' file.
:param config_file: The configuration file to be used. Default is 'config/llm_config.yaml'.
:return: The chosen language model.
"""
llm_str = load_config_key("llm", vector_name, filename=config_file)
model_lookup_filepath = get_module_filepath("lookup/model_lookup.yaml")
model_lookup, _ = load_config(model_lookup_filepath)
@@ -106,7 +129,16 @@ def get_llm(vector_name, model=None, config_file="config/llm_config.yaml"):
if llm_str is None:
raise NotImplementedError(f'No llm implemented for {llm_str}')


def get_llm_chat(vector_name, model=None, config_file="config/llm_config.yaml"):
"""
This function gets a chat model based on the 'llm' key in the given configuration file (default 'config/llm_config.yaml'). It returns the chosen chat model.
:param vector_name: The name of the vector to be used.
:param model: The model to be used. If not provided, a model is loaded from the 'config/llm_config.yaml' file.
:param config_file: The configuration file to be used. Default is 'config/llm_config.yaml'.
:return: The chosen chat model.
"""
llm_str = load_config_key("llm", vector_name, filename=config_file)
if not model:
model = load_config_key("model", vector_name, filename=config_file)
@@ -150,14 +182,28 @@ def get_llm_chat(vector_name, model=None, config_file="config/llm_config.yaml"):
if llm_str is None:
raise NotImplementedError(f'No llm implemented for {llm_str}')


def get_embeddings(vector_name):
"""
This function gets embeddings based on the 'llm' key in the 'config/llm_config.yaml' file. It returns the chosen embeddings.
:param vector_name: The name of the vector to be used.
:return: The chosen embeddings.
"""
llm_str = load_config_key("llm", vector_name, filename="config/llm_config.yaml")

return pick_embedding(llm_str)




def pick_embedding(llm_str: str):
"""
This function picks embeddings based on the 'llm' key. It returns the chosen embeddings.
:param llm_str: The 'llm' key to be used.
:return: The chosen embeddings.
"""
# Configure embeddings based on llm_str
if llm_str == 'openai':
@@ -175,4 +221,4 @@ def pick_embedding(llm_str: str):
return GoogleGenerativeAIEmbeddings(model="models/embedding-001") #TODO add embedding type

if llm_str is None:
raise NotImplementedError(f'No embeddings implemented for {llm_str}')
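
Taken together, these helpers resolve everything from the `llm` and `model` keys in the config file. A minimal sketch of how a caller might use them; the vector name `"sample"` is a hypothetical config entry, not part of this commit:

```python
from sunholo.components.llm import get_llm, get_llm_chat, pick_streaming

vector_name = "sample"  # hypothetical; must have an 'llm' entry in config/llm_config.yaml

llm = get_llm(vector_name)        # language model resolved from the 'llm' key
chat = get_llm_chat(vector_name)  # chat model; 'model' is read from config when not passed
if pick_streaming(vector_name):   # True for 'openai', 'gemini', or 'vertex'
    print("Configured backend supports streaming")
```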
6 changes: 5 additions & 1 deletion sunholo/components/prompt.py
@@ -92,6 +92,7 @@ def pick_prompt(vector_name, chat_history=[]):
return QA_PROMPT

def pick_chat_buddy(vector_name):
"""This function picks a chat buddy based on the 'chat_buddy' key in the 'config/llm_config.yaml' file. It returns the chosen chat buddy and a description of the chat buddy."""
chat_buddy = load_config_key("chat_buddy", vector_name, filename = "config/llm_config.yaml")
if chat_buddy is not None:
logging.info(f"Got chat buddy {chat_buddy} for {vector_name}")
@@ -101,19 +102,22 @@ def pick_chat_buddy(vector_name):


def pick_agent(vector_name):
"""This function checks if the 'agent' key in the 'config/llm_config.yaml' file is 'yes'. If it is, the function returns True. Otherwise, it returns False."""
agent_str = load_config_key("agent", vector_name, filename = "config/llm_config.yaml")
if agent_str == "yes":
return True

return False

def pick_shared_vectorstore(vector_name, embeddings):
"""This function picks a shared vector store based on the 'shared_vectorstore' key in the 'config/llm_config.yaml' file. It returns the chosen vector store."""
shared_vectorstore = load_config_key("shared_vectorstore", vector_name, filename = "config/llm_config.yaml")
vectorstore = pick_vectorstore(shared_vectorstore, embeddings)
return vectorstore


def get_chat_history(inputs, vector_name, last_chars=1000, summary_chars=1500) -> str:
"""This function gets the chat history and returns a string that contains the summarized chat history and the last 'last_chars' characters of the chat history. The function takes as parameters the chat history, the vector name, the number of last characters to include in the chat history, and the number of characters to include in the summary."""
from langchain.schema import Document
from ..summarise import summarise_docs

@@ -148,4 +152,4 @@ def get_chat_history(inputs, vector_name, last_chars=1000, summary_chars=1500) -> str:
summary = text_sum[:summary_chars]

# Concatenate the summary and the last `last_chars` characters of the chat history
return summary + "\n### Recent Chat History\n..." + recent_history
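
A short sketch of the config-driven helpers in prompt.py; the vector name is hypothetical, and the tuple unpacking of `pick_chat_buddy` is inferred from its docstring above:

```python
from sunholo.components.prompt import pick_agent, pick_chat_buddy

vector_name = "sample"  # hypothetical; must exist in config/llm_config.yaml

if pick_agent(vector_name):  # True only when the config sets agent: yes
    buddy, buddy_description = pick_chat_buddy(vector_name)
    print(f"Agent enabled with chat buddy {buddy}: {buddy_description}")
```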
4 changes: 3 additions & 1 deletion sunholo/components/retriever.py
@@ -27,6 +27,7 @@
logging = setup_logging()

def load_memories(vector_name):
"""This function loads memory settings for a given vector name from the 'config/llm_config.yaml' file. It returns the memory settings for the vector name. If no memory settings are found, it returns None."""
memories = load_config_key("memory", vector_name, filename="config/llm_config.yaml")
logging.info(f"Found memory settings for {vector_name}: {memories}")
if len(memories) == 0:
@@ -36,6 +37,7 @@ def load_memories(vector_name):
return memories

def pick_retriever(vector_name, embeddings=None):
"""This function creates a list of retrievers based on the memory settings for a given vector name. It returns a ContextualCompressionRetriever object. If no retrievers are created, it returns None. The function takes a vector name and an optional embeddings parameter."""

memories = load_memories(vector_name)

@@ -78,4 +80,4 @@ def pick_retriever(vector_name, embeddings=None):
base_compressor=pipeline, base_retriever=lotr,
k=3)

return retriever
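
A sketch of wiring `get_embeddings` into `pick_retriever`; the vector name and query are illustrative assumptions:

```python
from sunholo.components.llm import get_embeddings
from sunholo.components.retriever import pick_retriever

vector_name = "sample"  # hypothetical; needs a 'memory' entry in config/llm_config.yaml

embeddings = get_embeddings(vector_name)
retriever = pick_retriever(vector_name, embeddings=embeddings)
if retriever is not None:  # None when no retrievers could be built
    docs = retriever.get_relevant_documents("example query")
```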
10 changes: 8 additions & 2 deletions sunholo/components/vectorstore.py
@@ -18,7 +18,11 @@
logging = setup_logging()


def pick_vectorstore(vs_str, vector_name, embeddings):
"""This function picks a vector store based on the 'vs_str' parameter. It currently supports 'supabase', 'cloudsql', 'alloydb', and 'lancedb'."""
logging.debug('Picking vectorstore')

if vs_str == 'supabase':
@@ -69,7 +73,9 @@ def pick_vectorstore(vs_str, vector_name, embeddings):
logging.debug("Chose CloudSQL")

return vectorstore


elif vs_str == 'alloydb':
from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore
from google.cloud.alloydb.connector import IPTypes
@@ -139,4 +145,4 @@ def pick_vectorstore(vs_str, vector_name, embeddings):
return vectorstore

else:
raise NotImplementedError(f'No vectorstore implemented for {vs_str}')
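
A sketch of selecting one of the supported backends; `'lancedb'` and the vector name are illustrative, and each backend needs its own credentials and infrastructure:

```python
from sunholo.components.llm import get_embeddings
from sunholo.components.vectorstore import pick_vectorstore

vector_name = "sample"  # hypothetical
embeddings = get_embeddings(vector_name)

# vs_str must be one of: 'supabase', 'cloudsql', 'alloydb', 'lancedb'
vectorstore = pick_vectorstore("lancedb", vector_name, embeddings)
```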
