-
Notifications
You must be signed in to change notification settings - Fork 16
/
AI.kt
153 lines (142 loc) · 5.15 KB
/
AI.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package com.xebia.functional.xef
import ai.xef.openai.CustomModel
import ai.xef.openai.OpenAIModel
import com.xebia.functional.openai.apis.ChatApi
import com.xebia.functional.openai.apis.ImagesApi
import com.xebia.functional.openai.models.CreateChatCompletionRequestModel
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.fromEnvironment
import com.xebia.functional.xef.prompt.Prompt
import kotlin.coroutines.cancellation.CancellationException
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.typeOf
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.serializer
/**
 * Entry points for running prompts against OpenAI chat/image models and decoding the
 * response into a Kotlin type [A] via kotlinx.serialization.
 *
 * All heavy lifting is delegated to [Chat] / [Images]; this interface's companion object
 * only wires dependencies (API clients, conversation, serializers) together.
 */
sealed interface AI {

  /**
   * Implemented by enums used with [classify]: each enum type supplies the prompt
   * template that asks the model to pick one of its cases.
   */
  interface PromptClassifier {
    /** Renders the classification prompt for the given [input], [output] and [context]. */
    fun template(input: String, output: String, context: String): String
  }

  companion object {

    /**
     * Low-level factory: assembles a [Chat] runner that decodes model output into [A].
     *
     * @param target runtime [KType] of [A], used downstream to drive deserialization.
     * @param model the chat-completion model to call.
     * @param api the chat API client.
     * @param conversation conversation state/memory the call participates in.
     * @param enumSerializer when [A] is an enum, maps a returned case name to a value;
     *   `null` for non-enum targets.
     * @param caseSerializers serializers for sealed-hierarchy cases (empty when unused here).
     * @param serializer lazily provides the [KSerializer] for [A].
     */
    fun <A : Any> chat(
      target: KType,
      model: OpenAIModel<CreateChatCompletionRequestModel>,
      api: ChatApi,
      conversation: Conversation,
      enumSerializer: ((case: String) -> A)?,
      caseSerializers: List<KSerializer<A>>,
      serializer: () -> KSerializer<A>,
    ): Chat<A> =
      Chat(
        target = target,
        model = model,
        api = api,
        serializer = serializer,
        conversation = conversation,
        enumSerializer = enumSerializer,
        caseSerializers = caseSerializers
      )

    /**
     * Factory for the image-generation runner; both API clients default to
     * environment-configured instances via [fromEnvironment].
     */
    fun images(
      api: ImagesApi = fromEnvironment(::ImagesApi),
      chatApi: ChatApi = fromEnvironment(::ChatApi)
    ): Images = Images(api, chatApi)

    /**
     * Runs [prompt] expecting the model to answer with one of [A]'s enum case names,
     * and converts that name back to the enum value via [enumValueOf].
     *
     * `@PublishedApi internal` so the public inline functions below may call it.
     * The `UPPER_BOUND_VIOLATED` suppression is needed because `A : Any` is not
     * statically known to be an enum here — [chat] (the reified overload) only routes
     * to this function after checking the serial kind is [SerialKind.ENUM].
     */
    @PublishedApi
    internal suspend inline fun <reified A : Any> invokeEnum(
      prompt: Prompt<CreateChatCompletionRequestModel>,
      target: KType = typeOf<A>(),
      api: ChatApi = fromEnvironment(::ChatApi),
      conversation: Conversation = Conversation()
    ): A =
      chat(
        target = target,
        model = prompt.model,
        api = api,
        conversation = conversation,
        enumSerializer = { @Suppress("UPPER_BOUND_VIOLATED") enumValueOf<A>(it) },
        caseSerializers = emptyList()
      ) {
        serializer<A>()
      }
        .invoke(prompt)

    /**
     * Classify a prompt using a given enum.
     *
     * The prompt text is produced by the enum's own [PromptClassifier.template]
     * (taken from its first case — the template is assumed identical across cases).
     *
     * @param input The input to classify.
     * @param output The expected output to compare against.
     * @param context Additional context for the model.
     * @param model The model to use (defaults to `gpt_4_1106_preview`).
     * @param target The target enum type to return.
     * @param api The chat API to use.
     * @param conversation The conversation to use.
     * @return The classified enum case.
     * @throws IllegalArgumentException If the enum [E] has no values.
     */
    @AiDsl
    @Throws(IllegalArgumentException::class, CancellationException::class)
    suspend inline fun <reified E> classify(
      input: String,
      output: String,
      context: String,
      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
      target: KType = typeOf<E>(),
      api: ChatApi = fromEnvironment(::ChatApi),
      conversation: Conversation = Conversation()
    ): E where E : PromptClassifier, E : Enum<E> {
      // Any case will do for building the template; fail fast on an empty enum.
      val value = enumValues<E>().firstOrNull() ?: error("No enum values found")
      return invoke(
        prompt = value.template(input, output, context),
        model = model,
        target = target,
        api = api,
        conversation = conversation
      )
    }

    /**
     * Runs a plain-string [prompt] and decodes the answer as [A].
     * The model is wrapped in [CustomModel] by its string id before building the [Prompt].
     */
    @AiDsl
    suspend inline operator fun <reified A : Any> invoke(
      prompt: String,
      target: KType = typeOf<A>(),
      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
      api: ChatApi = fromEnvironment(::ChatApi),
      conversation: Conversation = Conversation()
    ): A = chat(Prompt(CustomModel(model.value), prompt), target, api, conversation)

    /** Runs an already-constructed [Prompt] and decodes the answer as [A]. */
    @AiDsl
    suspend inline operator fun <reified A : Any> invoke(
      prompt: Prompt<CreateChatCompletionRequestModel>,
      target: KType = typeOf<A>(),
      api: ChatApi = fromEnvironment(::ChatApi),
      conversation: Conversation = Conversation()
    ): A = chat(prompt, target, api, conversation)

    /**
     * Main reified entry point: inspects [A]'s serializer descriptor and dispatches —
     * enums go through [invokeEnum] (case-name matching), everything else is decoded
     * with [A]'s regular [KSerializer].
     *
     * @throws IllegalStateException (via [error]) when [target] has no resolvable
     *   serializer, i.e. its classifier is not a serializable [KClass].
     */
    @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
    @AiDsl
    suspend inline fun <reified A : Any> chat(
      prompt: Prompt<CreateChatCompletionRequestModel>,
      target: KType = typeOf<A>(),
      api: ChatApi = fromEnvironment(::ChatApi),
      conversation: Conversation = Conversation()
    ): A {
      // Reflectively obtain the SerialKind to tell enums apart from other targets.
      val kind =
        (target.classifier as? KClass<*>)?.serializer()?.descriptor?.kind
          ?: error("Cannot find SerialKind for $target")
      return when (kind) {
        SerialKind.ENUM -> invokeEnum<A>(prompt, target, api, conversation)
        else -> {
          chat(
            target = target,
            model = prompt.model,
            api = api,
            conversation = conversation,
            enumSerializer = null,
            caseSerializers = emptyList()
          ) {
            serializer<A>()
          }
            .invoke(prompt)
        }
      }
    }
  }
}