Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature - Classifier #695

Merged
merged 3 commits into from Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 39 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
Expand Up @@ -9,6 +9,7 @@ import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.fromEnvironment
import com.xebia.functional.xef.prompt.Prompt
import kotlin.coroutines.cancellation.CancellationException
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.typeOf
Expand All @@ -20,6 +21,10 @@ import kotlinx.serialization.serializer

sealed interface AI {

/**
 * Implemented by enum classes whose values are classification targets.
 *
 * [template] builds the prompt sent to the model from the given
 * `input`, `output` and `context` strings; `classify` invokes it on the
 * first enum value to produce the classification prompt.
 */
interface PromptClassifier {
  fun template(input: String, output: String, context: String): String
}

companion object {

fun <A : Any> chat(
Expand Down Expand Up @@ -65,6 +70,40 @@ sealed interface AI {
}
.invoke(prompt)

/**
 * Classify a prompt using a given enum.
 *
 * The first value of [E] supplies the classification prompt via
 * [PromptClassifier.template]; the model's reply is decoded back into a
 * value of [E].
 *
 * @param input The input to the model.
 * @param output The output to the model.
 * @param context The context to the model.
 * @param model The model to use.
 * @param target The target type to return.
 * @param api The chat API to use.
 * @param conversation The conversation to use.
 * @return The classified enum value.
 * @throws IllegalArgumentException If [E] declares no enum values.
 */
@AiDsl
@Throws(IllegalArgumentException::class, CancellationException::class)
suspend inline fun <reified E> classify(
  input: String,
  output: String,
  context: String,
  model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
  target: KType = typeOf<E>(),
  api: ChatApi = fromEnvironment(::ChatApi),
  conversation: Conversation = Conversation()
): E where E : PromptClassifier, E : Enum<E> {
  // `error(...)` would throw IllegalStateException, contradicting the
  // @Throws/KDoc contract above; requireNotNull throws the documented
  // IllegalArgumentException instead.
  val first =
    requireNotNull(enumValues<E>().firstOrNull()) { "No enum values found" }
  return invoke(
    prompt = first.template(input, output, context),
    model = model,
    target = target,
    api = api,
    conversation = conversation
  )
}

@AiDsl
suspend inline operator fun <reified A : Any> invoke(
prompt: String,
Expand Down
1 change: 1 addition & 0 deletions evaluator/build.gradle.kts
Expand Up @@ -18,6 +18,7 @@ java {
dependencies {
// serialization is part of the evaluator's public API surface
api(libs.kotlinx.serialization.json)
detektPlugins(project(":detekt-rules"))
// xef-core provides AI.PromptClassifier, implemented by the metrics enums
implementation(projects.xefCore)
}

detekt {
Expand Down
@@ -0,0 +1,27 @@
package com.xebia.functional.xef.evaluator.metrics

import com.xebia.functional.xef.AI

/**
 * Classifies whether an `output` is consistent (`yes`) or not (`no`) with the
 * given `input` and `context`, for use with `AI.classify`.
 */
enum class AnswerAccuracy : AI.PromptClassifier {
  yes,
  no;

  // Builds the classification prompt. Fixes vs. the previous version:
  // "expert en" -> "expert in", "it's consistent" -> "is consistent", and the
  // stray `|` after `"""` that injected a blank first line into the prompt.
  override fun template(input: String, output: String, context: String): String {
    return """|You are an expert in evaluating whether the `output` is consistent with the given `input` and `context`.
              | <input>
              | $input
              | </input>
              | <output>
              | $output
              | </output>
              | <context>
              | $context
              | </context>
              |Return one of the following:
              | - if the answer is consistent: `yes`
              | - if the answer is not consistent: `no`
           """
      .trimMargin()
  }
}
1 change: 1 addition & 0 deletions examples/build.gradle.kts
Expand Up @@ -16,6 +16,7 @@ java {

dependencies {
implementation(projects.xefCore)
implementation(projects.xefEvaluator)
implementation(projects.xefFilesystem)
implementation(projects.xefPdf)
implementation(projects.xefSql)
Expand Down
@@ -0,0 +1,27 @@
package com.xebia.functional.xef.dsl.classify

import com.xebia.functional.openai.models.CreateChatCompletionRequestModel
import com.xebia.functional.xef.AI
import com.xebia.functional.xef.evaluator.metrics.AnswerAccuracy

/**
 * Demonstrates `AI.classify` with the [AnswerAccuracy] enum: the model decides
 * whether an answer is consistent (`yes`) or inconsistent (`no`) with the
 * question it was asked.
 *
 * To build your own classifier, implement `AI.PromptClassifier` and override
 * `template` with the prompt to be used for classification.
 */
suspend fun main() {
  // Expected to classify as consistent: the output answers the input.
  val consistent =
    AI.classify<AnswerAccuracy>(
      input = "Do I love Xef?",
      output = "I love Xef",
      context = "The answer responds the question"
    )
  println(consistent)

  // Expected to classify as inconsistent: the output ignores the question.
  val inconsistent =
    AI.classify<AnswerAccuracy>(
      input = "Do I love Xef?",
      output = "I have three opened PRs",
      context = "The answer responds the question",
      model = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125
    )
  println(inconsistent)
}