Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it easier to manage server-level chat settings #729

Merged
merged 4 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 28 additions & 2 deletions src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
PublicConversation,
ReflectiveQuestion,
SearchModelConfig,
ServerChatSettings,
SpeechToTextModelOptions,
Subscription,
TextToImageModelConfig,
Expand Down Expand Up @@ -702,11 +703,36 @@ async def aget_conversation_config(user: KhojUser):

@staticmethod
def get_default_conversation_config():
return ChatModelOptions.objects.filter().first()
server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is None or server_chat_settings.default_model is None:
return ChatModelOptions.objects.filter().first()
return server_chat_settings.default_model

@staticmethod
async def aget_default_conversation_config():
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter()
.prefetch_related("default_model", "default_model__openai_config")
.afirst()
)
if server_chat_settings is None or server_chat_settings.default_model is None:
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
return server_chat_settings.default_model

@staticmethod
async def aget_summarizer_conversation_config():
server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter()
.prefetch_related(
"summarizer_model", "default_model", "default_model__openai_config", "summarizer_model__openai_config"
)
.afirst()
)
if server_chat_settings is None or (
server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None
):
return await ChatModelOptions.objects.filter().afirst()
return server_chat_settings.summarizer_model or server_chat_settings.default_model

@staticmethod
def create_conversation_from_public_conversation(
Expand Down
2 changes: 2 additions & 0 deletions src/khoj/database/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ProcessLock,
ReflectiveQuestion,
SearchModelConfig,
ServerChatSettings,
SpeechToTextModelOptions,
Subscription,
TextToImageModelConfig,
Expand Down Expand Up @@ -55,6 +56,7 @@ class KhojUserAdmin(UserAdmin):
admin.site.register(ClientApplication)
admin.site.register(GithubConfig)
admin.site.register(NotionConfig)
admin.site.register(ServerChatSettings)


@admin.register(Agent)
Expand Down
46 changes: 46 additions & 0 deletions src/khoj/database/migrations/0042_serverchatsettings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Generated by Django 4.2.10 on 2024-04-29 11:04

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("database", "0041_merge_20240505_1234"),
]

operations = [
migrations.CreateModel(
name="ServerChatSettings",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"default_model",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="default_model",
to="database.chatmodeloptions",
),
),
(
"summarizer_model",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="summarizer_model",
to="database.chatmodeloptions",
),
),
],
options={
"abstract": False,
},
),
]
9 changes: 9 additions & 0 deletions src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ class GithubRepoConfig(BaseModel):
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")


class ServerChatSettings(BaseModel):
default_model = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="default_model"
)
summarizer_model = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="summarizer_model"
)


class LocalOrgConfig(BaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
Expand Down
12 changes: 9 additions & 3 deletions src/khoj/routers/api_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,11 +613,17 @@ async def send_rate_limit_message(message: str):

if ConversationCommand.Webpage in conversation_commands:
try:
online_results = await read_webpages(defiltered_query, meta_log, location, send_status_update)
direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update)
webpages = []
for query in online_results:
for webpage in online_results[query]["webpages"]:
for query in direct_web_pages:
if online_results.get(query):
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
else:
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}

for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"])

await send_status_update(f"**📚 Read web pages**: {webpages}")
except ValueError as e:
logger.warning(
Expand Down
11 changes: 9 additions & 2 deletions src/khoj/routers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,13 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
corpus=corpus.strip(),
)

summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()

with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information, prompts.system_prompt_extract_relevant_information
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
chat_model_option=summarizer_model,
)

return response.strip()
Expand Down Expand Up @@ -449,8 +453,11 @@ async def send_message_to_model_wrapper(
message: str,
system_message: str = "",
response_type: str = "text",
chat_model_option: ChatModelOptions = None,
):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
conversation_config: ChatModelOptions = (
chat_model_option or await ConversationAdapters.aget_default_conversation_config()
)

if conversation_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
Expand Down