-
Notifications
You must be signed in to change notification settings - Fork 16
/
Embeddings.kt
40 lines (37 loc) · 1.32 KB
/
Embeddings.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
package com.xebia.functional.xef.llm
import ai.xef.openai.OpenAIModel
import ai.xef.openai.StandardModel
import arrow.fx.coroutines.parMap
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.CreateEmbeddingRequest
import com.xebia.functional.openai.models.CreateEmbeddingRequestModel
import com.xebia.functional.openai.models.Embedding
import com.xebia.functional.openai.models.ext.embedding.create.CreateEmbeddingRequestInput
/**
 * Computes embeddings for [texts] by splitting them into batches and sending one
 * `createEmbedding` request per batch.
 *
 * Batches are processed with Arrow's `parMap`, and their `data` lists are flattened into
 * a single result. The result order follows the batch order, assuming `parMap` preserves
 * input order (TODO confirm against the Arrow version in use).
 *
 * @param texts the strings to embed; an empty list returns an empty result without
 *   issuing any request.
 * @param chunkSize maximum number of strings placed in a single request.
 * @param embeddingRequestModel model used for every request; defaults to
 *   `text_embedding_ada_002`.
 * @return the embeddings from every batch, concatenated into one list.
 */
suspend fun EmbeddingsApi.embedDocuments(
  texts: List<String>,
  chunkSize: Int = 400,
  embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel> =
    StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002)
): List<Embedding> {
  // Nothing to embed: skip the network round-trip entirely.
  if (texts.isEmpty()) return emptyList()

  val batches = texts.chunked(chunkSize)
  val embeddingsPerBatch =
    batches.parMap { batch ->
      val request =
        CreateEmbeddingRequest(
          model = embeddingRequestModel,
          input = CreateEmbeddingRequestInput.StringArrayValue(batch)
        )
      createEmbedding(request).body().data
    }
  return embeddingsPerBatch.flatten()
}
/**
 * Computes the embedding for a single query string by delegating to `embedDocuments`
 * with a one-element list. The default `chunkSize` of `embedDocuments` is used.
 *
 * @param text the query to embed; an empty string returns an empty list without issuing
 *   any request.
 * @param embeddingRequestModel model used for the request.
 * @return the embeddings produced for [text], or an empty list when [text] is empty.
 */
suspend fun EmbeddingsApi.embedQuery(
  text: String,
  embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel>
): List<Embedding> {
  // Guard: an empty query is not sent to the API.
  if (text.isEmpty()) return emptyList()
  return embedDocuments(texts = listOf(text), embeddingRequestModel = embeddingRequestModel)
}