From 53a267c96a5b1943f5d08ab4a83e55e437119df0 Mon Sep 17 00:00:00 2001 From: raulraja Date: Tue, 19 Mar 2024 08:24:51 +0100 Subject: [PATCH] Filter for thread messages fetching --- .../xef/llm/assistants/AssistantThread.kt | 52 +++++++++++++++---- 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt index 3bb157905..5d1a7163f 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt @@ -2,6 +2,7 @@ package com.xebia.functional.xef.llm.assistants import arrow.fx.coroutines.parMap import com.xebia.functional.openai.apis.AssistantsApi +import com.xebia.functional.openai.apis.AssistantsApi.OrderListMessages import com.xebia.functional.openai.infrastructure.ApiClient import com.xebia.functional.openai.models.* import com.xebia.functional.openai.models.ext.assistant.RunStepDetailsMessageCreationObject @@ -47,7 +48,26 @@ class AssistantThread( suspend fun getMessage(messageId: String): MessageObject = api.getMessage(threadId, messageId).body() - suspend fun listMessages(): List = api.listMessages(threadId).body().data + data class ThreadMessagesFilter( + val limit: Int? = 20, + val order: OrderListMessages? = OrderListMessages.desc, + val after: String? = null, + val before: String? = null + ) + + suspend fun listMessages( + filter: ThreadMessagesFilter = ThreadMessagesFilter() + ): List = + api + .listMessages( + threadId = threadId, + limit = filter.limit, + order = filter.order, + after = filter.after, + before = filter.before + ) + .body() + .data suspend fun createRun(request: CreateRunRequest): RunObject = api.createRun(threadId, request).body().addMetrics(metric) @@ -57,14 +77,21 @@ class AssistantThread( suspend fun createRun(assistant: Assistant): RunObject = createRun(CreateRunRequest(assistantId = assistant.assistantId)) - suspend fun run(assistant: Assistant): Flow { + suspend fun run( + assistant: Assistant, + filter: ThreadMessagesFilter = ThreadMessagesFilter() + ): Flow { val run = createRun(assistant) - return awaitRun(assistant, run.id) + return awaitRun(assistant, run.id, filter) } - suspend fun run(assistant: Assistant, request: CreateRunRequest): Flow { + suspend fun run( + assistant: Assistant, + request: CreateRunRequest, + filter: ThreadMessagesFilter = ThreadMessagesFilter() + ): Flow { val run = createRun(request) - return awaitRun(assistant, run.id) + return awaitRun(assistant, run.id, filter) } suspend fun cancelRun(runId: String): RunObject = api.cancelRun(threadId, runId).body() @@ -80,7 +107,11 @@ class AssistantThread( data class Step(val runStep: RunStepObject) : RunDelta() } - fun awaitRun(assistant: Assistant, runId: String): Flow = flow { + private fun awaitRun( + assistant: Assistant, + runId: String, + filter: ThreadMessagesFilter + ): Flow = flow { val stepCache = mutableSetOf() // CacheTool val messagesCache = mutableSetOf() val runCache = mutableSetOf() @@ -89,7 +120,7 @@ class AssistantThread( while (run.status != RunObject.Status.completed) { checkSteps(assistant = assistant, runId = runId, cache = stepCache) delay(500) // To avoid excessive calls to OpenAI - checkMessages(runId, cache = messagesCache) + checkMessages(runId = runId, cache = messagesCache, filter = filter) delay(500) // To avoid excessive calls to OpenAI run = checkRun(runId = runId, cache = runCache) } @@ -122,7 +153,7 @@ class AssistantThread( ) ) } finally { - checkMessages(runId, cache = messagesCache) + checkMessages(runId = runId, cache = messagesCache, filter = filter) } } @@ -140,11 +171,12 @@ class AssistantThread( private suspend fun FlowCollector.checkMessages( runId: String, - cache: MutableSet + cache: MutableSet, + filter: ThreadMessagesFilter, ) { metric.assistantCreatedMessage(runId) { val messages = mutableListOf() - val updatedAndNewMessages = listMessages().filterNot { it in cache } + val updatedAndNewMessages = listMessages(filter).filterNot { it in cache } updatedAndNewMessages.forEach { message -> val content = message.content.filterNot { it.text?.value?.isBlank() ?: true } if (content.isNotEmpty() && message !in cache) {