diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt index 6ada5a106..0c6ae72f4 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt @@ -11,8 +11,10 @@ import com.xebia.functional.openai.models.ext.assistant.AssistantToolsCode import com.xebia.functional.openai.models.ext.assistant.AssistantToolsFunction import com.xebia.functional.openai.models.ext.assistant.AssistantToolsRetrieval import com.xebia.functional.xef.llm.fromEnvironment +import com.xebia.functional.xef.llm.models.functions.buildJsonSchema import io.ktor.util.logging.* import kotlinx.serialization.KSerializer +import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive @@ -42,7 +44,7 @@ class Assistant( assistantsApi ) - suspend inline fun getToolRegistered(name: String, args: String): JsonElement = + suspend inline fun getToolRegistered(name: String, args: String): ToolOutput = try { val toolConfig = toolsConfig.firstOrNull { it.functionObject.name == name } @@ -51,20 +53,26 @@ class Assistant( val tool: Tool = toolConfig.tool as Tool + val schema = buildJsonSchema(toolSerializer.outputSerializer.descriptor) val output: Any? = tool(input) - ApiClient.JSON_DEFAULT.encodeToJsonElement( - toolSerializer.outputSerializer as KSerializer, - output - ) + val result = + ApiClient.JSON_DEFAULT.encodeToJsonElement( + toolSerializer.outputSerializer as KSerializer, + output + ) + ToolOutput(schema, result) } catch (e: Exception) { val message = "Error calling to tool registered $name: ${e.message}" val logger = KtorSimpleLogger("Functions") logger.error(message, e) - JsonObject(mapOf("error" to JsonPrimitive(message))) + val result = JsonObject(mapOf("error" to JsonPrimitive(message))) + ToolOutput(JsonObject(emptyMap()), result) } companion object { + @Serializable data class ToolOutput(val schema: JsonObject, val result: JsonElement) + suspend operator fun invoke( model: String, name: String? = null, diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt index 4cb07800c..3bb157905 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt @@ -16,7 +16,6 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.FlowCollector import kotlinx.coroutines.flow.flow import kotlinx.serialization.encodeToString -import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonObject class AssistantThread( @@ -200,13 +199,12 @@ class AssistantThread( run.status == RunObject.Status.requires_action && run.requiredAction?.type == RunObjectRequiredAction.Type.submit_tool_outputs ) { - val results: Map = + val results: Map = calls .filter { it.function != null } .parMap { toolCall -> val function = toolCall.function!! - val result: JsonElement = - assistant.getToolRegistered(function.name, function.arguments) + val result = assistant.getToolRegistered(function.name, function.arguments) toolCall.id to result } .toMap() @@ -222,7 +220,11 @@ class AssistantThread( results.map { (toolCallId, result) -> SubmitToolOutputsRunRequestToolOutputsInner( toolCallId = toolCallId, - output = ApiClient.JSON_DEFAULT.encodeToString(result) + output = + ApiClient.JSON_DEFAULT.encodeToString( + Assistant.Companion.ToolOutput.serializer(), + result + ) ) } )