Skip to content

Commit

Permalink
Assistant metrics (#674)
Browse files Browse the repository at this point in the history
* First approach for assistant metrics

* Support for messages

* Adding usage on metrics

* OpenTelemetric comment
  • Loading branch information
javipacheco committed Mar 5, 2024
1 parent e42d443 commit cd2f923
Show file tree
Hide file tree
Showing 8 changed files with 432 additions and 50 deletions.
@@ -1,7 +1,9 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.openai.models.CreateChatCompletionResponse
import com.xebia.functional.openai.models.RunObject
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.metrics.Metric
import com.xebia.functional.xef.prompt.Prompt

suspend fun CreateChatCompletionResponse.addMetrics(
Expand Down Expand Up @@ -41,3 +43,8 @@ suspend fun <T> Prompt<T>.addMetrics(conversation: Conversation) {
if (functions.isNotEmpty())
conversation.metric.parameter("openai.chat_completion.functions", functions.map { it.name })
}

/**
 * Reports this run to [metric] through [Metric.assistantCreateRun] and hands the same instance
 * back, so the call can be chained onto any expression that yields a [RunObject].
 *
 * @param metric sink that receives the run.
 * @return the receiver, unchanged.
 */
suspend fun RunObject.addMetrics(metric: Metric): RunObject = also { run ->
  metric.assistantCreateRun(run)
}
Expand Up @@ -8,7 +8,6 @@ import com.xebia.functional.openai.models.CreateAssistantRequest
import com.xebia.functional.openai.models.ModifyAssistantRequest
import com.xebia.functional.openai.models.ext.assistant.AssistantTools
import com.xebia.functional.xef.llm.fromEnvironment
import io.ktor.client.statement.*
import io.ktor.util.logging.*
import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.JsonElement
Expand Down
Expand Up @@ -7,7 +7,9 @@ import com.xebia.functional.openai.models.*
import com.xebia.functional.openai.models.ext.assistant.RunStepDetailsMessageCreationObject
import com.xebia.functional.openai.models.ext.assistant.RunStepDetailsToolCallsObject
import com.xebia.functional.openai.models.ext.assistant.RunStepObjectStepDetails
import com.xebia.functional.xef.llm.addMetrics
import com.xebia.functional.xef.llm.fromEnvironment
import com.xebia.functional.xef.metrics.Metric
import kotlin.jvm.JvmName
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
Expand All @@ -19,6 +21,7 @@ import kotlinx.serialization.json.JsonObject

class AssistantThread(
val threadId: String,
val metric: Metric = Metric.EMPTY,
private val api: AssistantsApi = fromEnvironment(::AssistantsApi)
) {

Expand All @@ -28,24 +31,16 @@ class AssistantThread(
AssistantThread(api.modifyThread(threadId, request).body().id)

/**
 * Adds a user message, together with its attached file ids, to this thread.
 *
 * Builds the [CreateMessageRequest] and delegates to the request-based `createMessage` overload,
 * so every overload reaches the API through a single call site.
 *
 * @param message text content plus the ids of the files to attach.
 * @return the message object returned by the API.
 */
suspend fun createMessage(message: MessageWithFiles): MessageObject =
  createMessage(
    CreateMessageRequest(
      role = CreateMessageRequest.Role.user,
      content = message.content,
      fileIds = message.fileIds
    )
  )

/**
 * Adds a plain-text user message to this thread.
 *
 * Wraps [content] in a [CreateMessageRequest] with the `user` role and delegates to the
 * request-based `createMessage` overload.
 *
 * @param content text of the message.
 * @return the message object returned by the API.
 */
suspend fun createMessage(content: String): MessageObject =
  createMessage(CreateMessageRequest(role = CreateMessageRequest.Role.user, content = content))

/**
 * Sends [request] to the API's `createMessage` endpoint for this thread.
 *
 * @param request fully built message request.
 * @return the body of the API response.
 */
suspend fun createMessage(request: CreateMessageRequest): MessageObject {
  val response = api.createMessage(threadId, request)
  return response.body()
}
Expand All @@ -56,12 +51,12 @@ class AssistantThread(
/**
 * Fetches the messages of this thread.
 *
 * @return the `data` list of the API's `listMessages` response body.
 */
suspend fun listMessages(): List<MessageObject> {
  val listing = api.listMessages(threadId).body()
  return listing.data
}

/**
 * Creates a run on this thread and reports the resulting [RunObject] to [metric].
 *
 * @param request run creation request.
 * @return the run returned by the API.
 */
suspend fun createRun(request: CreateRunRequest): RunObject =
  api.createRun(threadId, request).body().addMetrics(metric)

/**
 * Retrieves the run identified by [runId] on this thread.
 *
 * @param runId id of the run to fetch.
 * @return the body of the API response.
 */
suspend fun getRun(runId: String): RunObject {
  val response = api.getRun(threadId, runId)
  return response.body()
}

/**
 * Creates a run of [assistant] on this thread.
 *
 * Delegates to the request-based `createRun` overload so that the run is reported to the
 * thread's metric in the same way as any other run.
 *
 * @param assistant assistant whose `assistantId` is used for the run.
 * @return the run returned by the API.
 */
suspend fun createRun(assistant: Assistant): RunObject =
  createRun(CreateRunRequest(assistantId = assistant.assistantId))

suspend fun run(assistant: Assistant): Flow<RunDelta> {
val run = createRun(assistant)
Expand Down Expand Up @@ -90,7 +85,7 @@ class AssistantThread(
while (run.status != RunObject.Status.completed) {
checkSteps(assistant = assistant, runId = runId, cache = stepCache)
delay(500) // To avoid excessive calls to OpenAI
checkMessages(cache = messagesCache)
checkMessages(runId, cache = messagesCache)
delay(500) // To avoid excessive calls to OpenAI
run = checkRun(runId = runId, cache = runCache)
}
Expand Down Expand Up @@ -123,31 +118,38 @@ class AssistantThread(
)
)
} finally {
checkMessages(cache = messagesCache)
checkMessages(runId, cache = messagesCache)
}
}

/**
 * Fetches the current state of the run through [metric] and emits it as a [RunDelta.Run] when
 * that state has not been seen before.
 *
 * @param runId id of the run to poll.
 * @param cache run states already emitted; updated in place.
 * @return the run as just fetched, whether or not it was emitted.
 */
private suspend fun FlowCollector<RunDelta>.checkRun(
  runId: String,
  cache: MutableSet<RunObject>
): RunObject {
  val run = metric.assistantCreateRun(runId) { getRun(runId) }
  // MutableSet.add returns true only when the element was not already present, so it doubles
  // as the "is this a new state?" test.
  if (cache.add(run)) {
    emit(RunDelta.Run(run))
  }
  return run
}

/**
 * Lists the thread's messages and emits a [RunDelta.ReceivedMessage] for each one that is not
 * yet in [cache] and has at least one non-blank text content item.
 *
 * The work runs inside [Metric.assistantCreatedMessage], which receives the list of messages
 * emitted by this call.
 *
 * @param runId id of the run the messages are reported under.
 * @param cache messages already emitted; updated in place.
 */
private suspend fun FlowCollector<RunDelta>.checkMessages(
  runId: String,
  cache: MutableSet<MessageObject>
) {
  metric.assistantCreatedMessage(runId) {
    val messages = mutableListOf<MessageObject>()
    val updatedAndNewMessages = listMessages().filterNot { it in cache }
    updatedAndNewMessages.forEach { message ->
      // Content items with no text, or with blank text, do not count as received content.
      val content = message.content.filterNot { it.text?.value?.isBlank() ?: true }
      // Re-check the cache: it is mutated while iterating over the snapshot taken above.
      if (content.isNotEmpty() && message !in cache) {
        cache.add(message)
        messages.add(message)
        emit(RunDelta.ReceivedMessage(message))
      }
    }
    messages
  }
}

Expand All @@ -163,7 +165,9 @@ class AssistantThread(
runId: String,
cache: MutableSet<RunStepObject>
) {
val steps = runSteps(runId)

val steps = runSteps(runId).map { metric.assistantCreateRunStep(runId) { it } }

steps.forEach { step ->
val calls = step.stepDetails.toolCalls()
// .filter {
Expand Down Expand Up @@ -201,20 +205,25 @@ class AssistantThread(
toolCall.id to result
}
.toMap()
api.submitToolOuputsToRun(
threadId = threadId,
runId = runId,
submitToolOutputsRunRequest =
SubmitToolOutputsRunRequest(
toolOutputs =
results.map { (toolCallId, result) ->
SubmitToolOutputsRunRequestToolOutputsInner(
toolCallId = toolCallId,
output = ApiClient.JSON_DEFAULT.encodeToString(result)
)
}

metric.assistantToolOutputsRun(runId) {
api
.submitToolOuputsToRun(
threadId = threadId,
runId = runId,
submitToolOutputsRunRequest =
SubmitToolOutputsRunRequest(
toolOutputs =
results.map { (toolCallId, result) ->
SubmitToolOutputsRunRequestToolOutputsInner(
toolCallId = toolCallId,
output = ApiClient.JSON_DEFAULT.encodeToString(result)
)
}
)
)
)
.body()
}
}
}
}
Expand All @@ -225,6 +234,7 @@ class AssistantThread(
suspend operator fun invoke(
messages: List<MessageWithFiles>,
metadata: JsonObject? = null,
metric: Metric = Metric.EMPTY,
api: AssistantsApi = fromEnvironment(::AssistantsApi)
): AssistantThread =
AssistantThread(
Expand All @@ -243,13 +253,15 @@ class AssistantThread(
)
.body()
.id,
metric,
api
)

@JvmName("createWithMessages")
suspend operator fun invoke(
messages: List<String>,
metadata: JsonObject? = null,
metric: Metric = Metric.EMPTY,
api: AssistantsApi = fromEnvironment(::AssistantsApi)
): AssistantThread =
AssistantThread(
Expand All @@ -264,25 +276,33 @@ class AssistantThread(
)
.body()
.id,
metric,
api
)

/**
 * Creates a new thread from already-built message requests and wraps it in an
 * [AssistantThread] that uses the given [metric] and [api].
 *
 * @param messages initial messages of the thread; empty by default.
 * @param metadata optional metadata passed to the thread creation request.
 * @param metric metric sink handed to the new [AssistantThread].
 * @param api API client used both to create the thread and by the returned instance.
 */
@JvmName("createWithRequests")
suspend operator fun invoke(
  messages: List<CreateMessageRequest> = emptyList(),
  metadata: JsonObject? = null,
  metric: Metric = Metric.EMPTY,
  api: AssistantsApi = fromEnvironment(::AssistantsApi)
): AssistantThread =
  AssistantThread(
    api.createThread(CreateThreadRequest(messages, metadata)).body().id,
    metric,
    api
  )

/**
 * Creates a new thread from [request] and wraps it in an [AssistantThread] that uses the given
 * [metric] and [api].
 *
 * @param request thread creation request.
 * @param metric metric sink handed to the new [AssistantThread].
 * @param api API client used both to create the thread and by the returned instance.
 */
suspend operator fun invoke(
  request: CreateThreadRequest,
  metric: Metric = Metric.EMPTY,
  api: AssistantsApi = fromEnvironment(::AssistantsApi)
): AssistantThread = AssistantThread(api.createThread(request).body().id, metric, api)

/**
 * Creates a thread and starts a run on it in one API call, then wraps the thread in an
 * [AssistantThread] that uses the given [metric] and [api].
 *
 * @param request combined thread-and-run creation request.
 * @param metric metric sink handed to the new [AssistantThread].
 * @param api API client used both for the call and by the returned instance.
 */
suspend operator fun invoke(
  request: CreateThreadAndRunRequest,
  metric: Metric = Metric.EMPTY,
  api: AssistantsApi = fromEnvironment(::AssistantsApi)
): AssistantThread {
  // NOTE(review): assumes the response body is a RunObject, as in the OpenAI API where
  // "create thread and run" returns the run. Its `id` is then the run id, so the thread has to
  // be identified by `threadId` instead — confirm against the generated AssistantsApi.
  val run = api.createThreadAndRun(request).body()
  return AssistantThread(run.threadId, metric, api)
}
}
}
@@ -1,6 +1,9 @@
package com.xebia.functional.xef.metrics

import arrow.atomic.AtomicInt
import com.xebia.functional.openai.models.MessageObject
import com.xebia.functional.openai.models.RunObject
import com.xebia.functional.openai.models.RunStepObject
import com.xebia.functional.xef.prompt.Prompt
import io.github.oshai.kotlinlogging.KotlinLogging
import io.github.oshai.kotlinlogging.Level
Expand Down Expand Up @@ -39,6 +42,70 @@ class LogsMetric(private val level: Level = Level.INFO) : Metric {
return output
}

/**
 * Logs the assistant id, thread id, run id and status of [runObject], one line per field, at
 * the configured level and the current indentation.
 */
override suspend fun assistantCreateRun(runObject: RunObject) {
  val fields =
    listOf(
      "AssistantId" to runObject.assistantId,
      "ThreadId" to runObject.threadId,
      "RunId" to runObject.id,
      "Status" to runObject.status.value
    )
  fields.forEach { (field, text) ->
    logger.at(level) { this.message = "${writeIndent(numberOfBlocks.get())}|-- $field: $text" }
  }
}

/**
 * Runs [block], logs the run it produces through the single-argument `assistantCreateRun`
 * overload, and returns that run.
 *
 * @param runId id of the run; not used by this implementation.
 * @param block computation that yields the run to log.
 */
override suspend fun assistantCreateRun(
  runId: String,
  block: suspend Metric.() -> RunObject
): RunObject = block().also { produced -> assistantCreateRun(produced) }

/**
 * Runs [block], logs how many messages it produced, and returns them.
 *
 * @param runId id of the run; not used by this implementation.
 * @param block computation that yields the messages to report.
 */
override suspend fun assistantCreatedMessage(
  runId: String,
  block: suspend Metric.() -> List<MessageObject>
): List<MessageObject> =
  block().also { produced ->
    logger.at(level) {
      this.message = "${writeIndent(numberOfBlocks.get())}|-- Size: ${produced.size}"
    }
  }

/**
 * Runs [block], logs the assistant id, thread id, run id and status of the step it produces,
 * and returns that step.
 *
 * Each line goes through [event], which writes it at the configured level with the current
 * indentation and the `|-- ` prefix.
 *
 * @param runId id of the run; not used by this implementation.
 * @param block computation that yields the run step to log.
 */
override suspend fun assistantCreateRunStep(
  runId: String,
  block: suspend Metric.() -> RunStepObject
): RunStepObject {
  val step = block()
  event("AssistantId: ${step.assistantId}")
  event("ThreadId: ${step.threadId}")
  event("RunId: ${step.runId}")
  event("Status: ${step.status.value}")
  return step
}

/**
 * Runs [block] and logs the resulting run exactly as the block-taking `assistantCreateRun`
 * overload does, to which this delegates.
 *
 * @param runId id of the run; forwarded but not used for logging.
 * @param block computation that yields the run to log.
 */
override suspend fun assistantToolOutputsRun(
  runId: String,
  block: suspend Metric.() -> RunObject
): RunObject = assistantCreateRun(runId, block)

/**
 * Logs [message] at the configured level, prefixed with the current indentation and `|-- `.
 */
override suspend fun event(message: String) {
  // Captured under a separate name so it cannot be confused with the builder's own `message`.
  val text = message
  logger.at(level) {
    this.message = "${writeIndent(numberOfBlocks.get())}|-- $text"
  }
}
Expand Down

0 comments on commit cd2f923

Please sign in to comment.