implement #18;
ellipsis-dev[bot] committed Mar 25, 2024
1 parent da8f8a8 commit 0132756
Showing 4 changed files with 53 additions and 9 deletions.
20 changes: 20 additions & 0 deletions sunholo/components/README.md
@@ -0,0 +1,20 @@
# Sunholo Components

This directory contains the core components of the Sunholo project. Each Python file serves a specific purpose in the overall functionality of the project.

## `__init__.py`
This file initializes the Sunholo components package by importing the public functions from the other modules in this directory.

## `llm.py`
This file contains the logic for selecting and configuring language models, chat models, and their embeddings, driven by the project configuration.

## `prompt.py`
This file contains the logic for generating prompts for the language model.

## `retriever.py`
This file contains the logic for retrieving relevant documents from the configured memories and vector stores.

## `vectorstore.py`
This file contains the logic for storing and retrieving vectors.

For a detailed description of the functions within each Python file, please refer to the respective file. Each function's documentation includes information about what the function does, its parameters, and its return value.
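As a rough sketch of how these components might fit together (hypothetical usage only; the exact signatures are defined in the files above, and `"my_vector"` is an illustrative config key, not one shipped with the project):

```python
# Hypothetical usage sketch based on the module descriptions above;
# "my_vector" is an assumed entry in config/llm_config.yaml.
from sunholo.components.llm import pick_llm
from sunholo.components.retriever import pick_retriever

llm, embeddings, llm_chat = pick_llm("my_vector")
retriever = pick_retriever("my_vector", embeddings=embeddings)
docs = retriever.get_relevant_documents("What is Sunholo?")
```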
20 changes: 19 additions & 1 deletion sunholo/components/llm.py
@@ -16,6 +16,9 @@

logging = setup_logging()

"""
This function picks a language model based on the `vector_name` parameter and returns the chosen language model, its embeddings, and its chat model.
"""
def pick_llm(vector_name):
logging.debug('Picking llm')

@@ -47,6 +50,9 @@ def pick_llm(vector_name):

return llm, embeddings, llm_chat

"""
This function checks if the language model specified by the `vector_name` parameter supports streaming and returns a boolean value.
"""
def pick_streaming(vector_name):

llm_str = load_config_key("llm", vector_name, filename = "config/llm_config.yaml")
@@ -57,6 +63,9 @@ def pick_streaming(vector_name):
return False


"""
This function gets a language model based on the `vector_name` and `model` parameters and the `config_file` configuration file.
"""
def get_llm(vector_name, model=None, config_file="config/llm_config.yaml"):
llm_str = load_config_key("llm", vector_name, filename=config_file)
model_lookup_filepath = get_module_filepath("lookup/model_lookup.yaml")
@@ -106,6 +115,9 @@ def get_llm(vector_name, model=None, config_file="config/llm_config.yaml"):
if llm_str is None:
raise NotImplementedError(f'No llm implemented for {llm_str}')

"""
This function gets a chat model based on the `vector_name` and `model` parameters and the `config_file` configuration file.
"""
def get_llm_chat(vector_name, model=None, config_file="config/llm_config.yaml"):
llm_str = load_config_key("llm", vector_name, filename=config_file)
if not model:
@@ -150,13 +162,19 @@ def get_llm_chat(vector_name, model=None, config_file="config/llm_config.yaml"):
if llm_str is None:
raise NotImplementedError(f'No llm implemented for {llm_str}')

"""
This function gets the embeddings for the language model specified by the `vector_name` parameter.
"""
def get_embeddings(vector_name):
llm_str = load_config_key("llm", vector_name, filename="config/llm_config.yaml")

return pick_embedding(llm_str)



"""
This function picks the embeddings based on the `llm_str` parameter.
"""
def pick_embedding(llm_str: str):
# get embedding directly from llm_str
# Configure embeddings based on llm_str
@@ -175,4 +193,4 @@ def pick_embedding(llm_str: str):
return GoogleGenerativeAIEmbeddings(model="models/embedding-001") #TODO add embedding type

if llm_str is None:
raise NotImplementedError(f'No embeddings implemented for {llm_str}')
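
A minimal sketch of how these lookups compose (hypothetical usage; `"my_vector"` is an assumed `config/llm_config.yaml` entry, not one defined in this commit):

```python
# Sketch only: "my_vector" is an assumed entry in config/llm_config.yaml.
from sunholo.components.llm import get_llm, get_embeddings, pick_streaming

llm = get_llm("my_vector")                # resolved via config + lookup/model_lookup.yaml
embeddings = get_embeddings("my_vector")  # embeddings derived from the same "llm" key
if pick_streaming("my_vector"):
    print("streaming supported for this llm")
```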
15 changes: 9 additions & 6 deletions sunholo/components/prompt.py
@@ -23,7 +23,7 @@


def pick_prompt(vector_name, chat_history=[]):
    """Pick a custom prompt based on `vector_name` and `chat_history` and return it."""
    logging.debug('Picking prompt')

prompt_str = load_config_key("prompt", vector_name, filename = "config/llm_config.yaml")
@@ -69,13 +69,14 @@ def pick_prompt(vector_name, chat_history=[]):
else:
follow_up += ".\n"

memory_str = "\n## Your Memory (ignore if not relevant to question)\n{context}\n"

current_conversation = ""
if chat_summary != "":
current_conversation =f"## Current Conversation\n{chat_summary}\n"
current_conversation = current_conversation.replace("{","{{").replace("}","}}") #escape {} characters

buddy_question = ""
my_q = "## Current Question\n{question}\n"
if agent_buddy:
@@ -92,28 +93,30 @@ def pick_prompt(vector_name, chat_history=[]):
return QA_PROMPT

def pick_chat_buddy(vector_name):
    """Pick a chat buddy based on `vector_name` and return the chat buddy and its description."""
chat_buddy = load_config_key("chat_buddy", vector_name, filename = "config/llm_config.yaml")
if chat_buddy is not None:
logging.info(f"Got chat buddy {chat_buddy} for {vector_name}")
buddy_description = load_config_key("chat_buddy_description", vector_name)
return chat_buddy, buddy_description
return None, None


def pick_agent(vector_name):
    """Determine whether an agent is configured for `vector_name` and return a boolean."""
agent_str = load_config_key("agent", vector_name, filename = "config/llm_config.yaml")
if agent_str == "yes":
return True

return False

def pick_shared_vectorstore(vector_name, embeddings):
    """Pick a shared vector store based on `vector_name` and `embeddings` and return it."""
shared_vectorstore = load_config_key("shared_vectorstore", vector_name, filename = "config/llm_config.yaml")
vectorstore = pick_vectorstore(shared_vectorstore, embeddings)
return vectorstore


def get_chat_history(inputs, vector_name, last_chars=1000, summary_chars=1500) -> str:
    """Build the chat history string for `vector_name` from `inputs`: a summary of older turns (up to `summary_chars` characters) followed by the last `last_chars` characters verbatim."""
from langchain.schema import Document
from ..summarise import summarise_docs

@@ -148,4 +151,4 @@ def get_chat_history(inputs, vector_name, last_chars=1000, summary_chars=1500) -> str:
summary = text_sum[:summary_chars]

# Concatenate the summary and the last `last_chars` characters of the chat history
return summary + "\n### Recent Chat History\n..." + recent_history
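
For illustration, a hedged sketch of calling `get_chat_history`; the shape of `inputs` is assumed here to be (human, ai) message pairs in the LangChain style and should be confirmed against the calling code:

```python
# Assumption: inputs is a list of (human, ai) message pairs; the real
# shape should be confirmed against the calling code.
from sunholo.components.prompt import get_chat_history

inputs = [("What is Sunholo?", "A GenAI project."),
          ("Where is it configured?", "In config/llm_config.yaml.")]
history = get_chat_history(inputs, "my_vector", last_chars=1000, summary_chars=1500)
print(history)  # summary of older turns + "### Recent Chat History" tail
```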
7 changes: 5 additions & 2 deletions sunholo/components/retriever.py
@@ -26,7 +26,9 @@

logging = setup_logging()


def load_memories(vector_name):
    """Load the memory settings for `vector_name` and return them."""
memories = load_config_key("memory", vector_name, filename="config/llm_config.yaml")
logging.info(f"Found memory settings for {vector_name}: {memories}")
if len(memories) == 0:
@@ -35,8 +37,9 @@ def load_memories(vector_name):

return memories

def pick_retriever(vector_name, embeddings=None):
    """Pick a retriever based on `vector_name` and `embeddings` and return it."""
memories = load_memories(vector_name)

retriever_list = []
@@ -78,4 +81,4 @@ def pick_retriever(vector_name, embeddings=None):
base_compressor=pipeline, base_retriever=lotr,
k=3)

return retriever
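
A short usage sketch for the retriever, under the same assumptions as above (`"my_vector"` is illustrative):

```python
# Sketch: pick_retriever merges one retriever per configured memory and
# wraps them in a compression pipeline returning up to k=3 documents.
from sunholo.components.retriever import pick_retriever

retriever = pick_retriever("my_vector")
for doc in retriever.get_relevant_documents("example query"):
    print(doc.page_content[:80])
```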
