Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter for thread messages fetching #692

Merged
merged 2 commits into from Mar 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -2,6 +2,7 @@ package com.xebia.functional.xef.llm.assistants

import arrow.fx.coroutines.parMap
import com.xebia.functional.openai.apis.AssistantsApi
import com.xebia.functional.openai.apis.AssistantsApi.OrderListMessages
import com.xebia.functional.openai.infrastructure.ApiClient
import com.xebia.functional.openai.models.*
import com.xebia.functional.openai.models.ext.assistant.RunStepDetailsMessageCreationObject
Expand Down Expand Up @@ -47,7 +48,26 @@ class AssistantThread(
/**
 * Retrieves one message of this thread.
 *
 * @param messageId identifier of the message to fetch.
 * @return the body of the API response for that message.
 */
suspend fun getMessage(messageId: String): MessageObject {
  val response = api.getMessage(threadId, messageId)
  return response.body()
}

/**
 * Lists the messages of this thread without passing any explicit query parameters.
 *
 * @return the `data` list carried by the API response body.
 */
suspend fun listMessages(): List<MessageObject> {
  val response = api.listMessages(threadId)
  return response.body().data
}
/**
 * Query parameters used when listing the messages of a thread.
 *
 * Every property is nullable and is handed as-is to the argument of the same name on the
 * list-messages API call.
 *
 * NOTE(review): the precise meaning and valid range of each value are defined by the API
 * client / remote endpoint, not by this class — confirm against the Assistants API reference.
 *
 * @property limit forwarded as `limit`; defaults to 20.
 * @property order forwarded as `order`; defaults to [OrderListMessages.desc].
 * @property after forwarded as `after`; presumably a pagination cursor (a message id) — confirm.
 * @property before forwarded as `before`; presumably a pagination cursor (a message id) — confirm.
 */
data class ThreadMessagesFilter(
val limit: Int? = 20,
val order: OrderListMessages? = OrderListMessages.desc,
val after: String? = null,
val before: String? = null
)

/**
 * Lists messages of this thread, forwarding each field of [filter] to the API call.
 *
 * @param filter query parameters for the listing; a default-constructed [ThreadMessagesFilter]
 *   is used when the argument is omitted.
 * @return the `data` list carried by the API response body.
 */
suspend fun listMessages(
  filter: ThreadMessagesFilter = ThreadMessagesFilter()
): List<MessageObject> {
  val response =
    api.listMessages(
      threadId = threadId,
      limit = filter.limit,
      order = filter.order,
      after = filter.after,
      before = filter.before
    )
  return response.body().data
}

/**
 * Creates a run on this thread from an explicit request.
 *
 * The created run is passed through `addMetrics(metric)` before being returned
 * (presumably to record run metrics — confirm against that helper).
 *
 * @param request the run-creation payload sent to the API.
 * @return the run object produced by `addMetrics`.
 */
suspend fun createRun(request: CreateRunRequest): RunObject {
  val createdRun = api.createRun(threadId, request).body()
  return createdRun.addMetrics(metric)
}
Expand All @@ -57,14 +77,21 @@ class AssistantThread(
/**
 * Creates a run on this thread for the given assistant.
 *
 * Builds a [CreateRunRequest] carrying only the assistant's id and delegates to the
 * request-based `createRun` overload.
 *
 * @param assistant the assistant whose `assistantId` is placed in the request.
 * @return the run object returned by the delegated call.
 */
suspend fun createRun(assistant: Assistant): RunObject {
  val request = CreateRunRequest(assistantId = assistant.assistantId)
  return createRun(request)
}

suspend fun run(assistant: Assistant): Flow<RunDelta> {
suspend fun run(
assistant: Assistant,
filter: ThreadMessagesFilter = ThreadMessagesFilter()
): Flow<RunDelta> {
val run = createRun(assistant)
return awaitRun(assistant, run.id)
return awaitRun(assistant, run.id, filter)
}

suspend fun run(assistant: Assistant, request: CreateRunRequest): Flow<RunDelta> {
suspend fun run(
assistant: Assistant,
request: CreateRunRequest,
filter: ThreadMessagesFilter = ThreadMessagesFilter()
): Flow<RunDelta> {
val run = createRun(request)
return awaitRun(assistant, run.id)
return awaitRun(assistant, run.id, filter)
}

/**
 * Requests cancellation of a run belonging to this thread.
 *
 * @param runId identifier of the run to cancel.
 * @return the body of the API response for the cancel call.
 */
suspend fun cancelRun(runId: String): RunObject {
  val response = api.cancelRun(threadId, runId)
  return response.body()
}
Expand All @@ -80,7 +107,11 @@ class AssistantThread(
data class Step(val runStep: RunStepObject) : RunDelta()
}

fun awaitRun(assistant: Assistant, runId: String): Flow<RunDelta> = flow {
private fun awaitRun(
assistant: Assistant,
runId: String,
filter: ThreadMessagesFilter
): Flow<RunDelta> = flow {
val stepCache = mutableSetOf<RunStepObject>() // CacheTool
val messagesCache = mutableSetOf<MessageObject>()
val runCache = mutableSetOf<RunObject>()
Expand All @@ -89,7 +120,7 @@ class AssistantThread(
while (run.status != RunObject.Status.completed) {
checkSteps(assistant = assistant, runId = runId, cache = stepCache)
delay(500) // To avoid excessive calls to OpenAI
checkMessages(runId, cache = messagesCache)
checkMessages(runId = runId, cache = messagesCache, filter = filter)
delay(500) // To avoid excessive calls to OpenAI
run = checkRun(runId = runId, cache = runCache)
}
Expand Down Expand Up @@ -122,7 +153,7 @@ class AssistantThread(
)
)
} finally {
checkMessages(runId, cache = messagesCache)
checkMessages(runId = runId, cache = messagesCache, filter = filter)
}
}

Expand All @@ -140,11 +171,12 @@ class AssistantThread(

private suspend fun FlowCollector<RunDelta>.checkMessages(
runId: String,
cache: MutableSet<MessageObject>
cache: MutableSet<MessageObject>,
filter: ThreadMessagesFilter,
) {
metric.assistantCreatedMessage(runId) {
val messages = mutableListOf<MessageObject>()
val updatedAndNewMessages = listMessages().filterNot { it in cache }
val updatedAndNewMessages = listMessages(filter).filterNot { it in cache }
updatedAndNewMessages.forEach { message ->
val content = message.content.filterNot { it.text?.value?.isBlank() ?: true }
if (content.isNotEmpty() && message !in cache) {
Expand Down