Skip to content

Commit

Permalink
Feature - Classifier (#695)
Browse files Browse the repository at this point in the history
* added classifier and sample metric

* added comments and updated example description

* removed unnecessary object
  • Loading branch information
Montagon committed Mar 22, 2024
1 parent 2555301 commit b87bb47
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 0 deletions.
39 changes: 39 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
Expand Up @@ -9,6 +9,7 @@ import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.fromEnvironment
import com.xebia.functional.xef.prompt.Prompt
import kotlin.coroutines.cancellation.CancellationException
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.typeOf
Expand All @@ -20,6 +21,10 @@ import kotlinx.serialization.serializer

sealed interface AI {

/**
 * Contract for enums usable as classification targets with [classify].
 *
 * Implementations build the prompt sent to the model from the caller-supplied
 * `input`, `output` and `context`; the model's reply is then decoded into one
 * of the enum's constants.
 */
interface PromptClassifier {
fun template(input: String, output: String, context: String): String
}

companion object {

fun <A : Any> chat(
Expand Down Expand Up @@ -65,6 +70,40 @@ sealed interface AI {
}
.invoke(prompt)

/**
 * Classify a prompt using a given enum.
 *
 * The enum type [E] must implement [PromptClassifier]; its [PromptClassifier.template]
 * builds the prompt, and the model's answer is decoded into one of the enum constants.
 *
 * @param input The input to the model.
 * @param output The output to the model.
 * @param context The context to the model.
 * @param model The model to use.
 * @param target The target type to return.
 * @param api The chat API to use.
 * @param conversation The conversation to use.
 * @return The classified enum value.
 * @throws IllegalArgumentException If the enum type [E] has no values.
 */
@AiDsl
@Throws(IllegalArgumentException::class, CancellationException::class)
suspend inline fun <reified E> classify(
  input: String,
  output: String,
  context: String,
  model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
  target: KType = typeOf<E>(),
  api: ChatApi = fromEnvironment(::ChatApi),
  conversation: Conversation = Conversation()
): E where E : PromptClassifier, E : Enum<E> {
  // `template` is an instance method, so any constant works to build the prompt.
  // `requireNotNull` (not `error`) so the thrown type matches the declared
  // @Throws(IllegalArgumentException) contract; `error` would throw IllegalStateException.
  val value =
    requireNotNull(enumValues<E>().firstOrNull()) { "No enum values found for ${E::class}" }
  return invoke(
    prompt = value.template(input, output, context),
    model = model,
    target = target,
    api = api,
    conversation = conversation
  )
}

@AiDsl
suspend inline operator fun <reified A : Any> invoke(
prompt: String,
Expand Down
1 change: 1 addition & 0 deletions evaluator/build.gradle.kts
Expand Up @@ -18,6 +18,7 @@ java {
// Evaluator module dependencies.
dependencies {
api(libs.kotlinx.serialization.json)
// Project-local detekt rule set applied during static analysis.
detektPlugins(project(":detekt-rules"))
// Brings in AI.PromptClassifier, which the evaluator metrics implement.
implementation(projects.xefCore)
}

detekt {
Expand Down
@@ -0,0 +1,27 @@
package com.xebia.functional.xef.evaluator.metrics

import com.xebia.functional.xef.AI

/**
 * Classifies whether an `output` answer is consistent with the given `input`
 * and `context`. Intended for use with `AI.classify<AnswerAccuracy>(...)`,
 * which decodes the model's reply into [yes] or [no].
 */
enum class AnswerAccuracy : AI.PromptClassifier {
  yes,
  no;

  // Fixed broken English in the prompt ("expert en evaluating", "it's consistent"),
  // which could degrade classification quality.
  override fun template(input: String, output: String, context: String): String {
    return """|
      |You are an expert in evaluating whether the `output` is consistent with the given `input` and `context`.
      | <input>
      | $input
      | </input>
      | <output>
      | $output
      | </output>
      | <context>
      | $context
      | </context>
      |Return one of the following:
      | - if the answer is consistent: `yes`
      | - if the answer is not consistent: `no`
      """
      .trimMargin()
  }
}
1 change: 1 addition & 0 deletions examples/build.gradle.kts
Expand Up @@ -16,6 +16,7 @@ java {

dependencies {
implementation(projects.xefCore)
implementation(projects.xefEvaluator)
implementation(projects.xefFilesystem)
implementation(projects.xefPdf)
implementation(projects.xefSql)
Expand Down
@@ -0,0 +1,27 @@
package com.xebia.functional.xef.dsl.classify

import com.xebia.functional.openai.models.CreateChatCompletionRequestModel
import com.xebia.functional.xef.AI
import com.xebia.functional.xef.evaluator.metrics.AnswerAccuracy

/**
 * Simple example of using the `AI.classify` function to judge the accuracy of
 * an answer. It relies on the `AnswerAccuracy` enum to decide whether the
 * answer is consistent with the question.
 *
 * To build your own classification, implement the `AI.PromptClassifier`
 * interface on an enum and override `template` with the prompt you want the
 * model to receive.
 */
suspend fun main() {
  // A consistent answer: expected classification is `yes`.
  val consistent =
    AI.classify<AnswerAccuracy>(
      input = "Do I love Xef?",
      output = "I love Xef",
      context = "The answer responds the question"
    )
  println(consistent)

  // An unrelated answer, classified with an explicitly chosen model.
  val inconsistent =
    AI.classify<AnswerAccuracy>(
      input = "Do I love Xef?",
      output = "I have three opened PRs",
      context = "The answer responds the question",
      model = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125
    )
  println(inconsistent)
}

0 comments on commit b87bb47

Please sign in to comment.