-
Notifications
You must be signed in to change notification settings - Fork 16
/
LocalVectorStore.kt
112 lines (100 loc) · 3.96 KB
/
LocalVectorStore.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package com.xebia.functional.xef.store
import ai.xef.openai.OpenAIModel
import ai.xef.openai.StandardModel
import arrow.atomic.Atomic
import arrow.atomic.AtomicInt
import arrow.atomic.getAndUpdate
import arrow.atomic.update
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.CreateEmbeddingRequestModel
import com.xebia.functional.openai.models.Embedding
import com.xebia.functional.xef.llm.embedDocuments
import com.xebia.functional.xef.llm.embedQuery
import com.xebia.functional.xef.llm.models.modelType
import kotlin.math.sqrt
/**
 * Immutable snapshot of everything the [LocalVectorStore] holds:
 * per-conversation memories (in insertion order), the raw documents added
 * via `addTexts`, and the embedding precomputed for each of those documents.
 */
private data class State(
  val orderedMemories: Map<ConversationId, List<Memory>>,
  val documents: List<String>,
  val precomputedEmbeddings: Map<String, Embedding>
) {
  companion object {
    /** A snapshot with no memories, no documents and no embeddings. */
    fun empty(): State =
      State(
        orderedMemories = emptyMap(),
        documents = emptyList(),
        precomputedEmbeddings = emptyMap(),
      )
  }
}

/** Shorthand for the atomically-updated store state. */
private typealias AtomicState = Atomic<State>
/**
 * In-memory [VectorStore] backed by a single atomically-updated [State].
 *
 * Documents added through [addTexts] are embedded eagerly via [embeddings] and
 * kept alongside their vectors, so [similaritySearch] only needs to embed the
 * query. Conversation memories are kept per [ConversationId], ordered by index.
 *
 * Thread-safety: all mutation goes through [Atomic.update], so concurrent
 * writers cannot lose updates.
 */
class LocalVectorStore
private constructor(
  private val embeddings: EmbeddingsApi,
  private val state: AtomicState,
  private val embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel>
) : VectorStore {

  /**
   * Public constructor starting from an empty store.
   *
   * @param embeddings API used to embed documents and queries.
   * @param embeddingRequestModel embedding model; defaults to ada-002.
   */
  constructor(
    embeddings: EmbeddingsApi,
    embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel> =
      StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002)
  ) : this(embeddings, Atomic(State.empty()), embeddingRequestModel)

  override val indexValue: AtomicInt = AtomicInt(0)

  /**
   * Fast-forwards [indexValue] to the highest memory index recorded for
   * [conversationId]. No-op when the conversation has no memories.
   */
  override fun updateIndexByConversationId(conversationId: ConversationId) {
    state.get().orderedMemories[conversationId]?.let { memories ->
      memories.maxByOrNull { it.index }?.let { lastMemory -> indexValue.set(lastMemory.index) }
    }
  }

  /**
   * Appends [memories] to their respective conversations, preserving the
   * existing order and adding new conversations as needed.
   */
  override suspend fun addMemories(memories: List<Memory>) {
    state.update { prevState ->
      // Renamed from the original's shadowing lambda parameter `memories`.
      val incoming = memories.groupBy { it.conversationId }
      prevState.copy(
        orderedMemories =
          (prevState.orderedMemories.keys + incoming.keys).associateWith { key ->
            // Existing memories first, then the newly added ones.
            (prevState.orderedMemories[key] ?: emptyList()) + (incoming[key] ?: emptyList())
          }
      )
    }
  }

  /**
   * Returns the most recent memories of [conversationId] that fit within
   * [limitTokens] for [model], in chronological (ascending-index) order.
   * Returns an empty list for an unknown conversation.
   */
  override suspend fun <T> memories(
    model: OpenAIModel<T>,
    conversationId: ConversationId,
    limitTokens: Int
  ): List<Memory> =
    state
      .get()
      .orderedMemories[conversationId]
      .orEmpty()
      .sortedByDescending { it.index }
      .reduceByLimitToken(model.modelType(), limitTokens)
      .reversed()

  /**
   * Embeds [texts] and stores both the documents and their vectors so later
   * similarity searches need not re-embed them.
   */
  override suspend fun addTexts(texts: List<String>) {
    val embeddingsList =
      embeddings.embedDocuments(texts, embeddingRequestModel = embeddingRequestModel)
    // Fixed: the original used getAndUpdate and discarded the returned
    // previous state; update expresses the intent and matches addMemories.
    state.update { prevState ->
      prevState.copy(
        documents = prevState.documents + texts,
        precomputedEmbeddings = prevState.precomputedEmbeddings + texts.zip(embeddingsList),
      )
    }
  }

  /**
   * Embeds [query] and returns up to [limit] stored documents ranked by
   * cosine similarity. Returns an empty list when the query yields no
   * embedding.
   */
  override suspend fun similaritySearch(query: String, limit: Int): List<String> {
    val queryEmbedding =
      embeddings.embedQuery(query, embeddingRequestModel = embeddingRequestModel).firstOrNull()
    return queryEmbedding?.let { similaritySearchByVector(it, limit) }.orEmpty()
  }

  /**
   * Returns up to [limit] stored documents ranked by cosine similarity to
   * [embedding]. Documents without a precomputed embedding are skipped.
   */
  override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String> {
    val state0 = state.get()
    return state0.documents
      .asSequence()
      .mapNotNull { doc -> state0.precomputedEmbeddings[doc]?.let { doc to it } }
      .map { (doc, e) -> doc to embedding.cosineSimilarity(e) }
      .sortedByDescending { (_, similarity) -> similarity }
      .take(limit)
      .map { (document, _) -> document }
      .toList()
  }

  /**
   * Cosine similarity between two embedding vectors.
   *
   * Returns 0.0 when either vector has zero magnitude (e.g. an empty or
   * all-zero embedding) — the original divided unconditionally and produced
   * NaN, which corrupts the similarity ordering above.
   */
  private fun Embedding.cosineSimilarity(other: Embedding): Double {
    val dotProduct = this.embedding.zip(other.embedding).sumOf { (a, b) -> (a * b).toDouble() }
    val magnitudeA = sqrt(this.embedding.sumOf { (it * it).toDouble() })
    val magnitudeB = sqrt(other.embedding.sumOf { (it * it).toDouble() })
    val denominator = magnitudeA * magnitudeB
    return if (denominator == 0.0) 0.0 else dotProduct / denominator
  }
}