Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector store allowing to change embeddings models and similarity strategy #686

Merged
merged 4 commits into from Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.llm

import ai.xef.openai.OpenAIModel
import ai.xef.openai.StandardModel
import arrow.fx.coroutines.parMap
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.CreateEmbeddingRequest
Expand All @@ -9,7 +11,9 @@ import com.xebia.functional.openai.models.ext.embedding.create.CreateEmbeddingRe

suspend fun EmbeddingsApi.embedDocuments(
texts: List<String>,
chunkSize: Int = 400
chunkSize: Int = 400,
embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel> =
StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002)
): List<Embedding> =
if (texts.isEmpty()) emptyList()
else
Expand All @@ -18,8 +22,7 @@ suspend fun EmbeddingsApi.embedDocuments(
.parMap {
createEmbedding(
CreateEmbeddingRequest(
model =
ai.xef.openai.StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002),
model = embeddingRequestModel,
input = CreateEmbeddingRequestInput.StringArrayValue(it)
)
)
Expand All @@ -28,5 +31,10 @@ suspend fun EmbeddingsApi.embedDocuments(
}
.flatten()

suspend fun EmbeddingsApi.embedQuery(text: String): List<Embedding> =
if (text.isNotEmpty()) embedDocuments(listOf(text)) else emptyList()
suspend fun EmbeddingsApi.embedQuery(
text: String,
embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel>
): List<Embedding> =
if (text.isNotEmpty())
embedDocuments(texts = listOf(text), embeddingRequestModel = embeddingRequestModel)
else emptyList()
@@ -1,11 +1,13 @@
package com.xebia.functional.xef.store

import ai.xef.openai.OpenAIModel
import ai.xef.openai.StandardModel
import arrow.atomic.Atomic
import arrow.atomic.AtomicInt
import arrow.atomic.getAndUpdate
import arrow.atomic.update
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.CreateEmbeddingRequestModel
import com.xebia.functional.openai.models.Embedding
import com.xebia.functional.xef.llm.embedDocuments
import com.xebia.functional.xef.llm.embedQuery
Expand All @@ -25,9 +27,16 @@ private data class State(
private typealias AtomicState = Atomic<State>

class LocalVectorStore
private constructor(private val embeddings: EmbeddingsApi, private val state: AtomicState) :
VectorStore {
constructor(embeddings: EmbeddingsApi) : this(embeddings, Atomic(State.empty()))
private constructor(
private val embeddings: EmbeddingsApi,
private val state: AtomicState,
private val embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel>
) : VectorStore {
constructor(
embeddings: EmbeddingsApi,
embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel> =
StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002)
) : this(embeddings, Atomic(State.empty()), embeddingRequestModel)

override val indexValue: AtomicInt = AtomicInt(0)

Expand Down Expand Up @@ -68,15 +77,17 @@ private constructor(private val embeddings: EmbeddingsApi, private val state: At
}

override suspend fun addTexts(texts: List<String>) {
val embeddingsList = embeddings.embedDocuments(texts)
val embeddingsList =
embeddings.embedDocuments(texts, embeddingRequestModel = embeddingRequestModel)
state.getAndUpdate { prevState ->
val newEmbeddings = prevState.precomputedEmbeddings + texts.zip(embeddingsList)
State(prevState.orderedMemories, prevState.documents + texts, newEmbeddings)
}
}

override suspend fun similaritySearch(query: String, limit: Int): List<String> {
val queryEmbedding = embeddings.embedQuery(query).firstOrNull()
val queryEmbedding =
embeddings.embedQuery(query, embeddingRequestModel = embeddingRequestModel).firstOrNull()
return queryEmbedding?.let { similaritySearchByVector(it, limit) }.orEmpty()
}

Expand Down
Expand Up @@ -4,6 +4,7 @@ import ai.xef.openai.OpenAIModel
import arrow.atomic.AtomicInt
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.ChatCompletionRole
import com.xebia.functional.openai.models.CreateEmbeddingRequestModel
import com.xebia.functional.openai.models.Embedding
import com.xebia.functional.xef.llm.embedQuery
import com.xebia.functional.xef.llm.models.modelType
Expand All @@ -24,6 +25,7 @@ open class Lucene(
private val writer: IndexWriter,
private val embeddings: EmbeddingsApi?,
private val similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN,
private val embeddingAIModel: OpenAIModel<CreateEmbeddingRequestModel>
) : VectorStore, AutoCloseable {

override val indexValue: AtomicInt = AtomicInt(0)
Expand All @@ -47,12 +49,13 @@ open class Lucene(
}

override suspend fun <T> memories(
model: OpenAIModel<T>, conversationId: ConversationId, limitTokens: Int): List<Memory> =
model: OpenAIModel<T>, conversationId: ConversationId, limitTokens: Int
): List<Memory> =
getMemoryByConversationId(conversationId).reduceByLimitToken(model.modelType(), limitTokens).reversed()

override suspend fun addTexts(texts: List<String>) {
texts.forEach {
val embedding = embeddings?.embedQuery(it)
val embedding = embeddings?.embedQuery(text = it, embeddingRequestModel = embeddingAIModel)
val doc =
Document().apply {
add(TextField("contents", it, Field.Store.YES))
Expand Down Expand Up @@ -125,8 +128,9 @@ class DirectoryLucene(
private val directory: Directory,
writerConfig: IndexWriterConfig = IndexWriterConfig(),
embeddings: EmbeddingsApi?,
embeddingAIModel: OpenAIModel<CreateEmbeddingRequestModel>,
similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN
) : Lucene(IndexWriter(directory, writerConfig), embeddings, similarity) {
) : Lucene(IndexWriter(directory, writerConfig), embeddings, similarity, embeddingAIModel) {
override fun close() {
super.close()
directory.close()
Expand All @@ -138,17 +142,19 @@ fun InMemoryLucene(
path: Path,
writerConfig: IndexWriterConfig = IndexWriterConfig(),
embeddings: EmbeddingsApi?,
embeddingAIModel: OpenAIModel<CreateEmbeddingRequestModel>,
similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN
): DirectoryLucene = DirectoryLucene(MMapDirectory(path), writerConfig, embeddings, similarity)
): DirectoryLucene = DirectoryLucene(MMapDirectory(path), writerConfig, embeddings, embeddingAIModel, similarity)

@JvmOverloads
fun InMemoryLuceneBuilder(
path: Path,
useAIEmbeddings: Boolean = true,
writerConfig: IndexWriterConfig = IndexWriterConfig(),
embeddingAIModel: OpenAIModel<CreateEmbeddingRequestModel>,
similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN
): (EmbeddingsApi) -> DirectoryLucene = { embeddings ->
InMemoryLucene(path, writerConfig, embeddings.takeIf { useAIEmbeddings }, similarity)
InMemoryLucene(path, writerConfig, embeddings.takeIf { useAIEmbeddings }, embeddingAIModel, similarity)
}

fun List<Embedding>.toFloatArray(): FloatArray = flatMap { it.embedding.map { it.toFloat() } }.toFloatArray()
Expand Down
Expand Up @@ -4,6 +4,7 @@ import ai.xef.openai.OpenAIModel
import arrow.atomic.AtomicInt
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.ChatCompletionRole
import com.xebia.functional.openai.models.CreateEmbeddingRequestModel
import com.xebia.functional.openai.models.Embedding
import com.xebia.functional.xef.llm.embedDocuments
import com.xebia.functional.xef.llm.embedQuery
Expand All @@ -20,6 +21,7 @@ class PGVectorStore(
private val collectionName: String,
private val distanceStrategy: PGDistanceStrategy,
private val preDeleteCollection: Boolean,
private val embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel>,
private val chunkSize: Int = 400
) : VectorStore {

Expand Down Expand Up @@ -83,7 +85,7 @@ class PGVectorStore(

override suspend fun addTexts(texts: List<String>): Unit =
dataSource.connection {
val embeddings = embeddings.embedDocuments(texts, chunkSize)
val embeddings = embeddings.embedDocuments(texts, chunkSize, embeddingRequestModel)
val collection = getCollection(collectionName)
texts.zip(embeddings) { text, embedding ->
val uuid = UUID.generateUUID()
Expand All @@ -105,7 +107,7 @@ class PGVectorStore(
if (!hasEmbeddings) return emptyList()

val embeddings =
embeddings.embedQuery(query).ifEmpty {
embeddings.embedQuery(query, embeddingRequestModel).ifEmpty {
throw IllegalStateException(
"Embedding for text: '$query', has not been properly generated"
)
Expand Down
Expand Up @@ -38,14 +38,17 @@ class PGVectorStoreSpec :
)
)

val embeddingsRequestModel = StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002)

fun StringSpecScope.pg() =
PGVectorStore(
vectorSize = 3,
dataSource = dataSource,
embeddings = TestEmbeddings(coroutineContext),
collectionName = "test_collection",
distanceStrategy = PGDistanceStrategy.Euclidean,
preDeleteCollection = false
preDeleteCollection = false,
embeddingRequestModel = embeddingsRequestModel
)

beforeContainer {
Expand All @@ -56,7 +59,8 @@ class PGVectorStoreSpec :
embeddings = TestEmbeddings(coroutineContext),
collectionName = "test_collection",
distanceStrategy = PGDistanceStrategy.Euclidean,
preDeleteCollection = false
preDeleteCollection = false,
embeddingRequestModel = embeddingsRequestModel
)
postgresVector.initialDbSetup()
postgresVector.createCollection()
Expand Down
@@ -1,6 +1,9 @@
package com.xebia.functional.xef.server.services

import ai.xef.openai.OpenAIModel
import ai.xef.openai.StandardModel
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.CreateEmbeddingRequestModel
import com.xebia.functional.xef.llm.fromEnvironment
import com.xebia.functional.xef.llm.fromToken
import com.xebia.functional.xef.server.http.routes.Provider
Expand All @@ -21,6 +24,9 @@ class PostgresVectorStoreService(
private val vectorSize: Int,
private val preDeleteCollection: Boolean = false,
private val chunkSize: Int = 400,
private val distanceStrategy: PGDistanceStrategy = PGDistanceStrategy.Euclidean,
private val embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel> =
StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002)
) : VectorStoreService() {

fun addCollection() {
Expand All @@ -45,8 +51,9 @@ class PostgresVectorStoreService(
dataSource = dataSource,
embeddings = embeddingsApi,
collectionName = collectionName,
distanceStrategy = PGDistanceStrategy.Euclidean,
distanceStrategy = distanceStrategy,
preDeleteCollection = preDeleteCollection,
embeddingRequestModel = embeddingRequestModel,
chunkSize = chunkSize
)
}
Expand Down