Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tool output response schema #685

Merged
merged 2 commits into from Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -11,8 +11,10 @@ import com.xebia.functional.openai.models.ext.assistant.AssistantToolsCode
import com.xebia.functional.openai.models.ext.assistant.AssistantToolsFunction
import com.xebia.functional.openai.models.ext.assistant.AssistantToolsRetrieval
import com.xebia.functional.xef.llm.fromEnvironment
import com.xebia.functional.xef.llm.models.functions.buildJsonSchema
import io.ktor.util.logging.*
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
Expand Down Expand Up @@ -42,7 +44,7 @@ class Assistant(
assistantsApi
)

suspend inline fun getToolRegistered(name: String, args: String): JsonElement =
suspend inline fun getToolRegistered(name: String, args: String): ToolOutput =
try {
val toolConfig = toolsConfig.firstOrNull { it.functionObject.name == name }

Expand All @@ -51,20 +53,26 @@ class Assistant(

val tool: Tool<Any?, Any?> = toolConfig.tool as Tool<Any?, Any?>

val schema = buildJsonSchema(toolSerializer.outputSerializer.descriptor)
val output: Any? = tool(input)
ApiClient.JSON_DEFAULT.encodeToJsonElement(
toolSerializer.outputSerializer as KSerializer<Any?>,
output
)
val result =
ApiClient.JSON_DEFAULT.encodeToJsonElement(
toolSerializer.outputSerializer as KSerializer<Any?>,
output
)
ToolOutput(schema, result)
} catch (e: Exception) {
val message = "Error calling to tool registered $name: ${e.message}"
val logger = KtorSimpleLogger("Functions")
logger.error(message, e)
JsonObject(mapOf("error" to JsonPrimitive(message)))
val result = JsonObject(mapOf("error" to JsonPrimitive(message)))
ToolOutput(JsonObject(emptyMap()), result)
}

companion object {

@Serializable data class ToolOutput(val schema: JsonObject, val result: JsonElement)

suspend operator fun invoke(
model: String,
name: String? = null,
Expand Down
Expand Up @@ -16,7 +16,6 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.flow
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject

class AssistantThread(
Expand Down Expand Up @@ -200,13 +199,12 @@ class AssistantThread(
run.status == RunObject.Status.requires_action &&
run.requiredAction?.type == RunObjectRequiredAction.Type.submit_tool_outputs
) {
val results: Map<String, JsonElement> =
val results: Map<String, Assistant.Companion.ToolOutput> =
calls
.filter { it.function != null }
.parMap { toolCall ->
val function = toolCall.function!!
val result: JsonElement =
assistant.getToolRegistered(function.name, function.arguments)
val result = assistant.getToolRegistered(function.name, function.arguments)
toolCall.id to result
}
.toMap()
Expand All @@ -222,7 +220,11 @@ class AssistantThread(
results.map { (toolCallId, result) ->
SubmitToolOutputsRunRequestToolOutputsInner(
toolCallId = toolCallId,
output = ApiClient.JSON_DEFAULT.encodeToString(result)
output =
ApiClient.JSON_DEFAULT.encodeToString(
Assistant.Companion.ToolOutput.serializer(),
result
)
)
}
)
Expand Down