diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
index 9ed01a68a..714311ead 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
@@ -9,6 +9,7 @@ import com.xebia.functional.xef.conversation.AiDsl
 import com.xebia.functional.xef.conversation.Conversation
 import com.xebia.functional.xef.llm.fromEnvironment
 import com.xebia.functional.xef.prompt.Prompt
+import kotlin.coroutines.cancellation.CancellationException
 import kotlin.reflect.KClass
 import kotlin.reflect.KType
 import kotlin.reflect.typeOf
@@ -20,6 +21,10 @@ import kotlinx.serialization.serializer
 
 sealed interface AI {
 
+  interface PromptClassifier {
+    fun template(input: String, output: String, context: String): String
+  }
+
   companion object {
 
     fun chat(
@@ -65,6 +70,40 @@ sealed interface AI {
       }
         .invoke(prompt)
 
+    /**
+     * Classify a prompt using a given enum.
+     *
+     * @param input The input to the model.
+     * @param output The output to the model.
+     * @param context The context to the model.
+     * @param model The model to use.
+     * @param target The target type to return.
+     * @param api The chat API to use.
+     * @param conversation The conversation to use.
+     * @return The classified enum.
+     * @throws IllegalArgumentException If no enum values are found.
+     */
+    @AiDsl
+    @Throws(IllegalArgumentException::class, CancellationException::class)
+    suspend inline fun <reified E> classify(
+      input: String,
+      output: String,
+      context: String,
+      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
+      target: KType = typeOf<E>(),
+      api: ChatApi = fromEnvironment(::ChatApi),
+      conversation: Conversation = Conversation()
+    ): E where E : PromptClassifier, E : Enum<E> {
+      val value = enumValues<E>().firstOrNull() ?: error("No enum values found")
+      return invoke(
+        prompt = value.template(input, output, context),
+        model = model,
+        target = target,
+        api = api,
+        conversation = conversation
+      )
+    }
+
     @AiDsl
     suspend inline operator fun <reified A : Any> invoke(
       prompt: String,
diff --git a/evaluator/build.gradle.kts b/evaluator/build.gradle.kts
index db9b3ab51..eb3355939 100644
--- a/evaluator/build.gradle.kts
+++ b/evaluator/build.gradle.kts
@@ -18,6 +18,7 @@ java {
 dependencies {
   api(libs.kotlinx.serialization.json)
   detektPlugins(project(":detekt-rules"))
+  implementation(projects.xefCore)
 }
 
 detekt {
diff --git a/evaluator/src/main/kotlin/com/xebia/functional/xef/evaluator/metrics/AnswerAccuracy.kt b/evaluator/src/main/kotlin/com/xebia/functional/xef/evaluator/metrics/AnswerAccuracy.kt
new file mode 100644
index 000000000..2fd0d7ca7
--- /dev/null
+++ b/evaluator/src/main/kotlin/com/xebia/functional/xef/evaluator/metrics/AnswerAccuracy.kt
@@ -0,0 +1,27 @@
+package com.xebia.functional.xef.evaluator.metrics
+
+import com.xebia.functional.xef.AI
+
+enum class AnswerAccuracy : AI.PromptClassifier {
+  yes,
+  no;
+
+  override fun template(input: String, output: String, context: String): String {
+    return """|
+        |You are an expert in evaluating whether the `output` is consistent with the given `input` and `context`.
+        |<input>
+        | $input
+        |</input>
+        |<output>
+        | $output
+        |</output>
+        |<context>
+        | $context
+        |</context>
+        |Return one of the following:
+        | - if the answer is consistent: `yes`
+        | - if the answer is not consistent: `no`
+      """
+      .trimMargin()
+  }
+}
diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts
index b0139755c..b29a285e2 100644
--- a/examples/build.gradle.kts
+++ b/examples/build.gradle.kts
@@ -16,6 +16,7 @@ java {
 
 dependencies {
   implementation(projects.xefCore)
+  implementation(projects.xefEvaluator)
   implementation(projects.xefFilesystem)
   implementation(projects.xefPdf)
   implementation(projects.xefSql)
diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/dsl/classify/AnswerAccuracy.kt b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/classify/AnswerAccuracy.kt
new file mode 100644
index 000000000..040570954
--- /dev/null
+++ b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/classify/AnswerAccuracy.kt
@@ -0,0 +1,27 @@
+package com.xebia.functional.xef.dsl.classify
+
+import com.xebia.functional.openai.models.CreateChatCompletionRequestModel
+import com.xebia.functional.xef.AI
+import com.xebia.functional.xef.evaluator.metrics.AnswerAccuracy
+
+/**
+ * This is a simple example of how to use the `AI.classify` function to classify the accuracy of an
+ * answer. In this case, it's using the `AnswerAccuracy` enum class to classify if the answer is
+ * consistent or not.
+ *
+ * You can extend the `AI.PromptClassifier` interface to create your own classification. Override
+ * the `template` function to define the prompt to be used in the classification.
+ */
+suspend fun main() {
+  println(
+    AI.classify<AnswerAccuracy>("Do I love Xef?", "I love Xef", "The answer responds the question")
+  )
+  println(
+    AI.classify<AnswerAccuracy>(
+      input = "Do I love Xef?",
+      output = "I have three opened PRs",
+      context = "The answer responds the question",
+      model = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125
+    )
+  )
+}